Conv horiz: expand hwy avx512. Further optimize for small blocks, but fall back to avx2 for blocks with height 32. - Use 8-bit pairwise multiply-accumulate (SatWidenMulPairwiseAdd) instead of 16-bit math for w <= 32 with even coefficients. - Halve filter coefficients to fit in int8_t and avoid overflow, adjusting final scaling shift to FILTER_BITS - 1. - Eliminate expensive 8-bit to 16-bit pixel promotion. - Add specialized unrolled loops for w = 4, 8, 16, and 32. All blocks now show a significant speedup except for small slowdowns for blocks 16x32, 32x32 and 64x32; I'll investigate these block sizes further. Size | avx2 | avx512 (diff) ------------------------------------------ 4x4 | 5.62µs | 4.03µs (-28.3%) 4x8 | 6.78µs | 5.17µs (-23.7%) 8x4 | 5.94µs | 4.03µs (-32.2%) 8x8 | 6.75µs | 5.17µs (-23.4%) 8x16 | 10.01µs | 7.66µs (-23.4%) 16x8 | 7.28µs | 6.49µs (-10.8%) 16x16 | 10.92µs | 10.47µs (-4.1%) 16x32 | 17.94µs | 19.83µs (+10.5%) 32x16 | 19.34µs | 19.59µs (+1.3%) 32x32 | 33.67µs | 38.31µs (+13.8%) 32x64 | 170.90µs | 153.10µs (-10.4%) 64x32 | 68.21µs | 76.28µs (+11.8%) 64x64 | 307.20µs | 151.80µs (-50.6%) 64x128 | 677.80µs | 305.30µs (-55.0%) 128x64 | 527.90µs | 298.60µs (-43.4%) 128x128 | 1.35ms | 593.90µs (-56.1%) Change-Id: I4134a9ca0e233855761f6b03c5f35e8fcf8e25fa
diff --git a/aom_dsp/convolve_hwy.h b/aom_dsp/convolve_hwy.h index e5be37e..0b28531 100644 --- a/aom_dsp/convolve_hwy.h +++ b/aom_dsp/convolve_hwy.h
@@ -134,26 +134,196 @@ HWY_ATTR inline void ConvolveHoriz2Tap(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst, ptrdiff_t dst_stride, const int16_t *filter_x, int w, int h) { - hn::ScalableTag<int16_t> mul_tag; - hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag; - auto filter_0 = hn::Set(mul_tag, filter_x[3]); - auto filter_1 = hn::Set(mul_tag, filter_x[4]); - auto vw = hn::Lanes(mul_tag); - for (int i = 0; i < h; ++i) { - for (int j = 0; j < w; j += vw) { - auto src0 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j])); - auto src1 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j + 1])); - auto mulv = hn::RoundingShiftRight<FILTER_BITS>(src0 * filter_0 + - src1 * filter_1); - auto mulv_demoted = hn::DemoteTo(pixel_tag, mulv); - if (j + static_cast<int>(vw) > w) { - hn::StoreN(mulv_demoted, pixel_tag, &dst[j], w - j); - } else { - hn::StoreU(mulv_demoted, pixel_tag, &dst[j]); + const bool can_use_optimized_path = + (w <= 32) && (filter_x[3] % 2 == 0) && (filter_x[4] % 2 == 0); + + if (can_use_optimized_path) { + hn::CappedTag<uint8_t, 16> tag8_16; + hn::CappedTag<int8_t, 16> tag_i8; + hn::CappedTag<int16_t, 8> tag16_8; + hn::CappedTag<uint8_t, 8> tag8_8; + hn::CappedTag<uint8_t, 4> tag8_4; + const auto bias_val = hn::Set(tag16_8, 1 << (FILTER_BITS - 2)); + + const auto shuffle_mask = hn::Dup128VecFromValues( + tag8_16, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8); + + const int8_t c3 = static_cast<int8_t>(filter_x[3] / 2); + const int8_t c4 = static_cast<int8_t>(filter_x[4] / 2); + + const auto coeff34 = hn::Dup128VecFromValues( + tag_i8, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4); + + if (w == 4) { + while (h >= 2) { + auto r0_d0 = hn::LoadU(tag8_16, src + 0); + auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + + auto r0_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff34); + auto r1_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), 
coeff34); + + hn::StoreU( + hn::LowerHalf(tag8_4, + hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_sum, bias_val)))), + tag8_4, dst); + hn::StoreU( + hn::LowerHalf(tag8_4, + hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_sum, bias_val)))), + tag8_4, dst + dst_stride); + + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } else if (w == 8) { + while (h >= 2) { + auto r0_d0 = hn::LoadU(tag8_16, src + 0); + auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + + auto r0_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff34); + auto r1_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff34); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_sum, bias_val))), + tag8_8, dst); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_sum, bias_val))), + tag8_8, dst + dst_stride); + + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } else if (w == 16) { + while (h >= 2) { + auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0); + auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8); + + auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8); + + auto r0_j0_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), coeff34); + auto r0_j8_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), coeff34); + + auto r1_j0_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), coeff34); + auto r1_j8_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), coeff34); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j0_sum, bias_val))), + tag8_8, dst + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j8_sum, bias_val))), + tag8_8, 
dst + 8); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j0_sum, bias_val))), + tag8_8, dst + dst_stride + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j8_sum, bias_val))), + tag8_8, dst + dst_stride + 8); + + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } else if (w == 32) { + while (h >= 2) { + { + auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0); + auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8); + + auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8); + + auto r0_j0_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), coeff34); + auto r0_j8_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), coeff34); + + auto r1_j0_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), coeff34); + auto r1_j8_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), coeff34); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j0_sum, bias_val))), + tag8_8, dst + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j8_sum, bias_val))), + tag8_8, dst + 8); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j0_sum, bias_val))), + tag8_8, dst + dst_stride + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j8_sum, bias_val))), + tag8_8, dst + dst_stride + 8); + } + { + auto r0_j16_d0 = hn::LoadU(tag8_16, src + 16); + auto r0_j24_d0 = hn::LoadU(tag8_16, src + 24); + + auto r1_j16_d0 = hn::LoadU(tag8_16, src + src_stride + 16); + auto r1_j24_d0 = hn::LoadU(tag8_16, src + src_stride + 24); + + auto r0_j16_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j16_d0, shuffle_mask), coeff34); + auto r0_j24_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, 
hn::TableLookupBytes(r0_j24_d0, shuffle_mask), coeff34); + + auto r1_j16_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j16_d0, shuffle_mask), coeff34); + auto r1_j24_sum = hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j24_d0, shuffle_mask), coeff34); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j16_sum, bias_val))), + tag8_8, dst + 16); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j24_sum, bias_val))), + tag8_8, dst + 24); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j16_sum, bias_val))), + tag8_8, dst + dst_stride + 16); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j24_sum, bias_val))), + tag8_8, dst + dst_stride + 24); + } + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; } } - src += src_stride; - dst += dst_stride; + } else { + hn::ScalableTag<int16_t> mul_tag; + hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag; + auto filter_0 = hn::Set(mul_tag, filter_x[3]); + auto filter_1 = hn::Set(mul_tag, filter_x[4]); + auto vw = hn::Lanes(mul_tag); + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; j += vw) { + auto src0 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j])); + auto src1 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j + 1])); + auto mulv = hn::RoundingShiftRight<FILTER_BITS>(src0 * filter_0 + + src1 * filter_1); + auto mulv_demoted = hn::DemoteTo(pixel_tag, mulv); + if (j + static_cast<int>(vw) > w) { + hn::StoreN(mulv_demoted, pixel_tag, &dst[j], w - j); + } else { + hn::StoreU(mulv_demoted, pixel_tag, &dst[j]); + } + } + src += src_stride; + dst += dst_stride; + } } } @@ -175,59 +345,286 @@ HWY_ATTR inline void ConvolveHoriz4Tap(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst, ptrdiff_t dst_stride, const int16_t *filter_x, int w, int h) { - hn::CappedTag<int16_t, 16> tag16; - hn::CappedTag<int16_t, 4> filter_tag; - auto f_vec = 
hn::LoadU(filter_tag, filter_x + 2); - // All filter values are even, halve to reduce intermediate precision - // requirements. - f_vec = hn::ShiftRight<1>(f_vec); + const bool can_use_optimized_path = + (w <= 32) && (filter_x[2] % 2 == 0) && (filter_x[3] % 2 == 0) && + (filter_x[4] % 2 == 0) && (filter_x[5] % 2 == 0); - if (w == 4) { - // Each iteration processes a 4x4 block - do { - auto src0 = LoadUnaligned4x4(tag16, src, src_stride); - auto src1 = LoadUnaligned4x4(tag16, src + 1, src_stride); - auto src2 = LoadUnaligned4x4(tag16, src + 2, src_stride); - auto src3 = LoadUnaligned4x4(tag16, src + 3, src_stride); - auto result = - Convolve4_8(tag16, filter_tag, src0, src1, src2, src3, f_vec); - StoreUnaligned4x4(tag16, dst, dst_stride, result); - h -= 4; - src += 4 * src_stride; - dst += 4 * dst_stride; - } while (h > 0); - } else if (w == 8) { - // Each iteration processes a 2x8 block - do { - auto src0 = LoadUnaligned2x8(tag16, src, src_stride); - auto src1 = LoadUnaligned2x8(tag16, src + 1, src_stride); - auto src2 = LoadUnaligned2x8(tag16, src + 2, src_stride); - auto src3 = LoadUnaligned2x8(tag16, src + 3, src_stride); - auto result = - Convolve4_8(tag16, filter_tag, src0, src1, src2, src3, f_vec); - StoreUnaligned2x8(tag16, dst, dst_stride, result); - h -= 2; - src += 2 * src_stride; - dst += 2 * dst_stride; - } while (h > 0); - } else if (w == 16) { - // One 1x16 block a time - do { - hn::Rebind<uint8_t, decltype(tag16)> tag8; - auto src0 = hn::PromoteTo(tag16, hn::LoadU(tag8, src)); - auto src1 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 1)); - auto src2 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 2)); - auto src3 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 3)); - auto result = - Convolve4_8(tag16, filter_tag, src0, src1, src2, src3, f_vec); - hn::StoreU(hn::DemoteTo(tag8, result), tag8, dst); - h--; - src += src_stride; - dst += dst_stride; - } while (h > 0); + if (can_use_optimized_path) { + hn::CappedTag<uint8_t, 16> tag8_16; + hn::CappedTag<int8_t, 
16> tag_i8; + hn::CappedTag<int16_t, 8> tag16_8; + hn::CappedTag<uint8_t, 8> tag8_8; + hn::CappedTag<uint8_t, 4> tag8_4; + const auto bias_val = hn::Set(tag16_8, 1 << (FILTER_BITS - 2)); + + const auto shuffle_mask = hn::Dup128VecFromValues( + tag8_16, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8); + + const int8_t c2 = static_cast<int8_t>(filter_x[2] / 2); + const int8_t c3 = static_cast<int8_t>(filter_x[3] / 2); + const int8_t c4 = static_cast<int8_t>(filter_x[4] / 2); + const int8_t c5 = static_cast<int8_t>(filter_x[5] / 2); + + const auto coeff23 = hn::Dup128VecFromValues( + tag_i8, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3); + + const auto coeff45 = hn::Dup128VecFromValues( + tag_i8, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5); + + if (w == 4) { + while (h >= 2) { + auto r0_d0 = hn::LoadU(tag8_16, src + 0); + auto r0_d2 = hn::LoadU(tag8_16, src + 2); + + auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2); + + auto r0_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask), coeff45)); + + auto r1_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask), coeff45)); + + hn::StoreU( + hn::LowerHalf(tag8_4, + hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_sum, bias_val)))), + tag8_4, dst); + hn::StoreU( + hn::LowerHalf(tag8_4, + hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_sum, bias_val)))), + tag8_4, dst + dst_stride); + + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } else if (w == 8) { + while (h >= 2) { + auto r0_d0 = hn::LoadU(tag8_16, src + 0); + auto r0_d2 = hn::LoadU(tag8_16, src + 2); + + auto r1_d0 = hn::LoadU(tag8_16, src + 
src_stride + 0); + auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2); + + auto r0_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask), coeff45)); + + auto r1_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask), coeff45)); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_sum, bias_val))), + tag8_8, dst); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_sum, bias_val))), + tag8_8, dst + dst_stride); + + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } else if (w == 16) { + while (h >= 2) { + auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0); + auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2); + + auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8); + auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10); + + auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2); + + auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8); + auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10); + + auto r0_j0_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask), + coeff45)); + + auto r0_j8_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask), + coeff45)); + + auto r1_j0_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask), + coeff45)); + + auto r1_j8_sum = 
hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask), + coeff45)); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j0_sum, bias_val))), + tag8_8, dst + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j8_sum, bias_val))), + tag8_8, dst + 8); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j0_sum, bias_val))), + tag8_8, dst + dst_stride + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j8_sum, bias_val))), + tag8_8, dst + dst_stride + 8); + + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } else if (w == 32) { + while (h >= 2) { + { + auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0); + auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2); + + auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8); + auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10); + + auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2); + + auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8); + auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10); + + auto r0_j0_sum = + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), + coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask), + coeff45)); + + auto r0_j8_sum = + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), + coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask), + coeff45)); + + auto r1_j0_sum = + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), + coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask), + coeff45)); + + auto r1_j8_sum = + 
hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), + coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask), + coeff45)); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j0_sum, bias_val))), + tag8_8, dst + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j8_sum, bias_val))), + tag8_8, dst + 8); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j0_sum, bias_val))), + tag8_8, dst + dst_stride + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j8_sum, bias_val))), + tag8_8, dst + dst_stride + 8); + } + { + auto r0_j16_d0 = hn::LoadU(tag8_16, src + 16); + auto r0_j16_d2 = hn::LoadU(tag8_16, src + 18); + + auto r0_j24_d0 = hn::LoadU(tag8_16, src + 24); + auto r0_j24_d2 = hn::LoadU(tag8_16, src + 26); + + auto r1_j16_d0 = hn::LoadU(tag8_16, src + src_stride + 16); + auto r1_j16_d2 = hn::LoadU(tag8_16, src + src_stride + 18); + + auto r1_j24_d0 = hn::LoadU(tag8_16, src + src_stride + 24); + auto r1_j24_d2 = hn::LoadU(tag8_16, src + src_stride + 26); + + auto r0_j16_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j16_d0, shuffle_mask), + coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j16_d2, shuffle_mask), + coeff45)); + + auto r0_j24_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j24_d0, shuffle_mask), + coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j24_d2, shuffle_mask), + coeff45)); + + auto r1_j16_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j16_d0, shuffle_mask), + coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j16_d2, shuffle_mask), + coeff45)); + + auto r1_j24_sum = hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j24_d0, shuffle_mask), + 
coeff23), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j24_d2, shuffle_mask), + coeff45)); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j16_sum, bias_val))), + tag8_8, dst + 16); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j24_sum, bias_val))), + tag8_8, dst + 24); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j16_sum, bias_val))), + tag8_8, dst + dst_stride + 16); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j24_sum, bias_val))), + tag8_8, dst + dst_stride + 24); + } + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } } else { hn::ScalableTag<int16_t> mul_tag; hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag; + hn::CappedTag<int16_t, 4> filter_tag; + auto f_vec = hn::LoadU(filter_tag, filter_x + 2); + f_vec = hn::ShiftRight<1>(f_vec); auto vw = hn::Lanes(mul_tag); for (int i = 0; i < h; ++i) { for (int j = 0; j < w; j += vw) { @@ -278,82 +675,447 @@ return hn::RoundingShiftRight<FILTER_BITS - 1>(res); } -DECLARE_ALIGNED(32, static const uint8_t, filt_global[]) = { - 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 0, 1, 1, - 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 2, 3, 3, 4, 4, 5, - 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 2, 3, 3, 4, 4, 5, 5, 6, 6, - 7, 7, 8, 8, 9, 9, 10, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, - 10, 11, 11, 12, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, - 12, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 6, 7, - 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14 -}; - HWY_ATTR inline void ConvolveHoriz8Tap(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst, ptrdiff_t dst_stride, const int16_t *filter_x, int w, int h) { - hn::CappedTag<int16_t, 16> tag16; - hn::CappedTag<int16_t, 8> filter_tag; - auto f_vec = hn::LoadU(filter_tag, filter_x); - // All filter values are even, halve to reduce intermediate precision - // requirements. 
- f_vec = hn::ShiftRight<1>(f_vec); + const bool can_use_optimized_path = + (w <= 32) && (filter_x[0] % 2 == 0) && (filter_x[1] % 2 == 0) && + (filter_x[2] % 2 == 0) && (filter_x[3] % 2 == 0) && + (filter_x[4] % 2 == 0) && (filter_x[5] % 2 == 0) && + (filter_x[6] % 2 == 0) && (filter_x[7] % 2 == 0); - if (w == 4) { - do { - auto src0 = LoadUnaligned4x4(tag16, src, src_stride); - auto src1 = LoadUnaligned4x4(tag16, src + 1, src_stride); - auto src2 = LoadUnaligned4x4(tag16, src + 2, src_stride); - auto src3 = LoadUnaligned4x4(tag16, src + 3, src_stride); - auto src4 = LoadUnaligned4x4(tag16, src + 4, src_stride); - auto src5 = LoadUnaligned4x4(tag16, src + 5, src_stride); - auto src6 = LoadUnaligned4x4(tag16, src + 6, src_stride); - auto src7 = LoadUnaligned4x4(tag16, src + 7, src_stride); - auto result = Convolve8_8(tag16, filter_tag, src0, src1, src2, src3, src4, - src5, src6, src7, f_vec); - StoreUnaligned4x4(tag16, dst, dst_stride, result); - h -= 4; - src += 4 * src_stride; - dst += 4 * dst_stride; - } while (h > 0); - } else if (w == 8) { - // Each iteration processes a 2x8 block - do { - auto src0 = LoadUnaligned2x8(tag16, src, src_stride); - auto src1 = LoadUnaligned2x8(tag16, src + 1, src_stride); - auto src2 = LoadUnaligned2x8(tag16, src + 2, src_stride); - auto src3 = LoadUnaligned2x8(tag16, src + 3, src_stride); - auto src4 = LoadUnaligned2x8(tag16, src + 4, src_stride); - auto src5 = LoadUnaligned2x8(tag16, src + 5, src_stride); - auto src6 = LoadUnaligned2x8(tag16, src + 6, src_stride); - auto src7 = LoadUnaligned2x8(tag16, src + 7, src_stride); - auto result = Convolve8_8(tag16, filter_tag, src0, src1, src2, src3, src4, - src5, src6, src7, f_vec); - StoreUnaligned2x8(tag16, dst, dst_stride, result); - h -= 2; - src += 2 * src_stride; - dst += 2 * dst_stride; - } while (h > 0); - } else if (w == 16) { - // One 1x16 block a time - do { - hn::Rebind<uint8_t, decltype(tag16)> tag8; - auto src0 = hn::PromoteTo(tag16, hn::LoadU(tag8, src)); - auto src1 = 
hn::PromoteTo(tag16, hn::LoadU(tag8, src + 1)); - auto src2 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 2)); - auto src3 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 3)); - auto src4 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 4)); - auto src5 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 5)); - auto src6 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 6)); - auto src7 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 7)); - auto result = Convolve8_8(tag16, filter_tag, src0, src1, src2, src3, src4, - src5, src6, src7, f_vec); - hn::StoreU(hn::DemoteTo(tag8, result), tag8, dst); - h--; - src += src_stride; - dst += dst_stride; - } while (h > 0); + if (can_use_optimized_path) { + hn::CappedTag<uint8_t, 16> tag8_16; + hn::CappedTag<int8_t, 16> tag_i8; + hn::CappedTag<int16_t, 8> tag16_8; + hn::CappedTag<uint8_t, 8> tag8_8; + hn::CappedTag<uint8_t, 4> tag8_4; + const auto bias_val = hn::Set(tag16_8, 1 << (FILTER_BITS - 2)); + + const auto shuffle_mask = hn::Dup128VecFromValues( + tag8_16, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8); + + const int8_t c0 = static_cast<int8_t>(filter_x[0] / 2); + const int8_t c1 = static_cast<int8_t>(filter_x[1] / 2); + const int8_t c2 = static_cast<int8_t>(filter_x[2] / 2); + const int8_t c3 = static_cast<int8_t>(filter_x[3] / 2); + const int8_t c4 = static_cast<int8_t>(filter_x[4] / 2); + const int8_t c5 = static_cast<int8_t>(filter_x[5] / 2); + const int8_t c6 = static_cast<int8_t>(filter_x[6] / 2); + const int8_t c7 = static_cast<int8_t>(filter_x[7] / 2); + + const auto coeff01 = hn::Dup128VecFromValues( + tag_i8, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1); + + const auto coeff23 = hn::Dup128VecFromValues( + tag_i8, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3); + + const auto coeff45 = hn::Dup128VecFromValues( + tag_i8, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5); + + const auto coeff67 = hn::Dup128VecFromValues( + tag_i8, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7, 
c6, c7); + + if (w == 4) { + while (h >= 2) { + auto r0_d0 = hn::LoadU(tag8_16, src + 0); + auto r0_d2 = hn::LoadU(tag8_16, src + 2); + auto r0_d4 = hn::LoadU(tag8_16, src + 4); + auto r0_d6 = hn::LoadU(tag8_16, src + 6); + + auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2); + auto r1_d4 = hn::LoadU(tag8_16, src + src_stride + 4); + auto r1_d6 = hn::LoadU(tag8_16, src + src_stride + 6); + + auto r0_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d6, shuffle_mask), + coeff67))); + + auto r1_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d6, shuffle_mask), + coeff67))); + + hn::StoreU( + hn::LowerHalf(tag8_4, + hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_sum, bias_val)))), + tag8_4, dst); + hn::StoreU( + hn::LowerHalf(tag8_4, + hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_sum, bias_val)))), + tag8_4, dst + dst_stride); + + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } else if (w == 8) { + while (h >= 2) { + auto r0_d0 = hn::LoadU(tag8_16, src + 0); + auto r0_d2 = hn::LoadU(tag8_16, src + 2); + auto r0_d4 = hn::LoadU(tag8_16, src + 4); + auto r0_d6 = hn::LoadU(tag8_16, src + 6); + + auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + auto r1_d2 = hn::LoadU(tag8_16, src + 
src_stride + 2); + auto r1_d4 = hn::LoadU(tag8_16, src + src_stride + 4); + auto r1_d6 = hn::LoadU(tag8_16, src + src_stride + 6); + + auto r0_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_d6, shuffle_mask), + coeff67))); + + auto r1_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_d6, shuffle_mask), + coeff67))); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_sum, bias_val))), + tag8_8, dst); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_sum, bias_val))), + tag8_8, dst + dst_stride); + + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } else if (w == 16) { + while (h >= 2) { + auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0); + auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2); + auto r0_j0_d4 = hn::LoadU(tag8_16, src + 4); + auto r0_j0_d6 = hn::LoadU(tag8_16, src + 6); + + auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8); + auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10); + auto r0_j8_d4 = hn::LoadU(tag8_16, src + 12); + auto r0_j8_d6 = hn::LoadU(tag8_16, src + 14); + + auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2); + auto r1_j0_d4 = hn::LoadU(tag8_16, src + src_stride + 4); + auto r1_j0_d6 = hn::LoadU(tag8_16, src + src_stride + 6); + + auto r1_j8_d0 = 
hn::LoadU(tag8_16, src + src_stride + 8); + auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10); + auto r1_j8_d4 = hn::LoadU(tag8_16, src + src_stride + 12); + auto r1_j8_d6 = hn::LoadU(tag8_16, src + src_stride + 14); + + auto r0_j0_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d6, shuffle_mask), + coeff67))); + + auto r0_j8_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d6, shuffle_mask), + coeff67))); + + auto r1_j0_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d6, shuffle_mask), + coeff67))); + + auto r1_j8_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d6, shuffle_mask), + 
coeff67))); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j0_sum, bias_val))), + tag8_8, dst + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j8_sum, bias_val))), + tag8_8, dst + 8); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j0_sum, bias_val))), + tag8_8, dst + dst_stride + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j8_sum, bias_val))), + tag8_8, dst + dst_stride + 8); + + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } else if (w == 32) { + while (h >= 2) { + { + auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0); + auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2); + auto r0_j0_d4 = hn::LoadU(tag8_16, src + 4); + auto r0_j0_d6 = hn::LoadU(tag8_16, src + 6); + + auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8); + auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10); + auto r0_j8_d4 = hn::LoadU(tag8_16, src + 12); + auto r0_j8_d6 = hn::LoadU(tag8_16, src + 14); + + auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0); + auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2); + auto r1_j0_d4 = hn::LoadU(tag8_16, src + src_stride + 4); + auto r1_j0_d6 = hn::LoadU(tag8_16, src + src_stride + 6); + + auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8); + auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10); + auto r1_j8_d4 = hn::LoadU(tag8_16, src + src_stride + 12); + auto r1_j8_d6 = hn::LoadU(tag8_16, src + src_stride + 14); + + auto r0_j0_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j0_d6, shuffle_mask), + coeff67))); + + auto r0_j8_sum = hn::Add( + 
hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j8_d6, shuffle_mask), + coeff67))); + + auto r1_j0_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j0_d6, shuffle_mask), + coeff67))); + + auto r1_j8_sum = hn::Add( + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask), + coeff23)), + hn::Add(hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j8_d6, shuffle_mask), + coeff67))); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j0_sum, bias_val))), + tag8_8, dst + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j8_sum, bias_val))), + tag8_8, dst + 8); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j0_sum, bias_val))), + tag8_8, dst + dst_stride + 0); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j8_sum, bias_val))), + tag8_8, dst + dst_stride + 8); + } + { + auto r0_j16_d0 = hn::LoadU(tag8_16, src + 16); + auto r0_j16_d2 = hn::LoadU(tag8_16, src + 18); + auto r0_j16_d4 = hn::LoadU(tag8_16, src + 20); + auto r0_j16_d6 = 
hn::LoadU(tag8_16, src + 22); + + auto r0_j24_d0 = hn::LoadU(tag8_16, src + 24); + auto r0_j24_d2 = hn::LoadU(tag8_16, src + 26); + auto r0_j24_d4 = hn::LoadU(tag8_16, src + 28); + auto r0_j24_d6 = hn::LoadU(tag8_16, src + 30); + + auto r1_j16_d0 = hn::LoadU(tag8_16, src + src_stride + 16); + auto r1_j16_d2 = hn::LoadU(tag8_16, src + src_stride + 18); + auto r1_j16_d4 = hn::LoadU(tag8_16, src + src_stride + 20); + auto r1_j16_d6 = hn::LoadU(tag8_16, src + src_stride + 22); + + auto r1_j24_d0 = hn::LoadU(tag8_16, src + src_stride + 24); + auto r1_j24_d2 = hn::LoadU(tag8_16, src + src_stride + 26); + auto r1_j24_d4 = hn::LoadU(tag8_16, src + src_stride + 28); + auto r1_j24_d6 = hn::LoadU(tag8_16, src + src_stride + 30); + + auto r0_j16_sum = hn::Add( + hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j16_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j16_d2, shuffle_mask), + coeff23)), + hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j16_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j16_d6, shuffle_mask), + coeff67))); + + auto r0_j24_sum = hn::Add( + hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j24_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j24_d2, shuffle_mask), + coeff23)), + hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j24_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r0_j24_d6, shuffle_mask), + coeff67))); + + auto r1_j16_sum = hn::Add( + hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j16_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j16_d2, shuffle_mask), + coeff23)), + hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j16_d4, shuffle_mask), + coeff45), + 
hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j16_d6, shuffle_mask), + coeff67))); + + auto r1_j24_sum = hn::Add( + hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j24_d0, shuffle_mask), + coeff01), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j24_d2, shuffle_mask), + coeff23)), + hn::Add( + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j24_d4, shuffle_mask), + coeff45), + hn::SatWidenMulPairwiseAdd( + tag16_8, hn::TableLookupBytes(r1_j24_d6, shuffle_mask), + coeff67))); + + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j16_sum, bias_val))), + tag8_8, dst + 16); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r0_j24_sum, bias_val))), + tag8_8, dst + 24); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j16_sum, bias_val))), + tag8_8, dst + dst_stride + 16); + hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>( + hn::Add(r1_j24_sum, bias_val))), + tag8_8, dst + dst_stride + 24); + } + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } + } } else { - // This tag will have 32 lanes (for avx512) or 16 lanes (for avx2) + hn::CappedTag<int16_t, 8> filter_tag; + auto f_vec = hn::LoadU(filter_tag, filter_x); + f_vec = hn::ShiftRight<1>(f_vec); hn::ScalableTag<int16_t> mul_tag; hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag; auto vw = hn::Lanes(mul_tag);
diff --git a/aom_dsp/x86/convolve_hwy_avx512.cc b/aom_dsp/x86/convolve_hwy_avx512.cc index c1aa904..6255704 100644 --- a/aom_dsp/x86/convolve_hwy_avx512.cc +++ b/aom_dsp/x86/convolve_hwy_avx512.cc
@@ -32,10 +32,11 @@ const int16_t *filter_x, int x_step_q4, const int16_t *filter_y, int y_step_q4, int w, int h) { - // Fallback to AVX2 for small block sizes (w <= 16) where the handwritten - // AVX2 implementation was measured to be faster than the Highway AVX512 - // implementation in benchmarks. - if (w <= 16) { + // Blocks of height 32 (16x32, 32x32 and 64x32) are ~10% slower than avx2, + // while all other block sizes are significantly faster. Fall back to avx2 + // when h == 32. + // TODO(jianj): Investigate and optimize for blocks of height 32. + if (h == 32) { aom_convolve8_horiz_avx2(src, src_stride, dst, dst_stride, filter_x, x_step_q4, filter_y, y_step_q4, w, h); } else {