| // Copyright 2019 Google LLC |
| // SPDX-License-Identifier: Apache-2.0 |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // 512-bit AVX512 vectors and operations. |
| // External include guard in highway.h - see comment there. |
| |
| // WARNING: most operations do not cross 128-bit block boundaries. In |
| // particular, "Broadcast", pack and zip behavior may be surprising. |
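| // For example, with u32 lanes {0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15}, |
| // Broadcast<0> returns {0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12} rather than |
| // all-zero lanes, because it operates independently within each 128-bit block. |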
| |
| // Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL |
| #include "third_party/highway/hwy/base.h" |
| |
| // Avoid uninitialized warnings in GCC's avx512fintrin.h - see |
| // https://github.com/google/highway/issues/710 |
| HWY_DIAGNOSTICS(push) |
| #if HWY_COMPILER_GCC_ACTUAL |
| HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") |
| HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, |
| ignored "-Wmaybe-uninitialized") |
| #endif |
| |
| #include <immintrin.h> // AVX2+ |
| |
| #if HWY_COMPILER_CLANGCL |
| // Including <immintrin.h> should be enough, but Clang's headers helpfully skip |
| // including these headers when _MSC_VER is defined, as is the case with |
| // clang-cl, so include them directly here. |
| // clang-format off |
| #include <smmintrin.h> |
| |
| #include <avxintrin.h> |
| // avxintrin defines __m256i and must come before avx2intrin. |
| #include <avx2intrin.h> |
| #include <f16cintrin.h> |
| #include <fmaintrin.h> |
| |
| #include <avx512fintrin.h> |
| #include <avx512vlintrin.h> |
| #include <avx512bwintrin.h> |
| #include <avx512vlbwintrin.h> |
| #include <avx512dqintrin.h> |
| #include <avx512vldqintrin.h> |
| #include <avx512cdintrin.h> |
| #include <avx512vlcdintrin.h> |
| |
| #if HWY_TARGET <= HWY_AVX3_DL |
| #include <avx512bitalgintrin.h> |
| #include <avx512vlbitalgintrin.h> |
| #include <avx512vbmiintrin.h> |
| #include <avx512vbmivlintrin.h> |
| #include <avx512vbmi2intrin.h> |
| #include <avx512vlvbmi2intrin.h> |
| #include <avx512vpopcntdqintrin.h> |
| #include <avx512vpopcntdqvlintrin.h> |
| #include <avx512vnniintrin.h> |
| #include <avx512vlvnniintrin.h> |
| // Must come after avx512fintrin, else will not define 512-bit intrinsics. |
| #include <vaesintrin.h> |
| #include <vpclmulqdqintrin.h> |
| #include <gfniintrin.h> |
| #endif // HWY_TARGET <= HWY_AVX3_DL |
| |
| #if HWY_TARGET <= HWY_AVX3_SPR |
| #include <avx512fp16intrin.h> |
| #include <avx512vlfp16intrin.h> |
| #endif // HWY_TARGET <= HWY_AVX3_SPR |
| |
| // clang-format on |
| #endif // HWY_COMPILER_CLANGCL |
| |
| // For half-width vectors. Already includes base.h and shared-inl.h. |
| #include "third_party/highway/hwy/ops/x86_256-inl.h" |
| |
| HWY_BEFORE_NAMESPACE(); |
| namespace hwy { |
| namespace HWY_NAMESPACE { |
| |
| namespace detail { |
| |
| template <typename T> |
| struct Raw512 { |
| using type = __m512i; |
| }; |
| #if HWY_HAVE_FLOAT16 |
| template <> |
| struct Raw512<float16_t> { |
| using type = __m512h; |
| }; |
| #endif // HWY_HAVE_FLOAT16 |
| template <> |
| struct Raw512<float> { |
| using type = __m512; |
| }; |
| template <> |
| struct Raw512<double> { |
| using type = __m512d; |
| }; |
| |
| // Template arg: sizeof(lane type) |
| template <size_t size> |
| struct RawMask512 {}; |
| template <> |
| struct RawMask512<1> { |
| using type = __mmask64; |
| }; |
| template <> |
| struct RawMask512<2> { |
| using type = __mmask32; |
| }; |
| template <> |
| struct RawMask512<4> { |
| using type = __mmask16; |
| }; |
| template <> |
| struct RawMask512<8> { |
| using type = __mmask8; |
| }; |
| |
| } // namespace detail |
| |
| template <typename T> |
| class Vec512 { |
| using Raw = typename detail::Raw512<T>::type; |
| |
| public: |
| using PrivateT = T; // only for DFromV |
| static constexpr size_t kPrivateN = 64 / sizeof(T); // only for DFromV |
| |
| // Compound assignment. Only usable if there is a corresponding non-member |
| // binary operator overload. For example, only f32 and f64 support division. |
| HWY_INLINE Vec512& operator*=(const Vec512 other) { |
| return *this = (*this * other); |
| } |
| HWY_INLINE Vec512& operator/=(const Vec512 other) { |
| return *this = (*this / other); |
| } |
| HWY_INLINE Vec512& operator+=(const Vec512 other) { |
| return *this = (*this + other); |
| } |
| HWY_INLINE Vec512& operator-=(const Vec512 other) { |
| return *this = (*this - other); |
| } |
| HWY_INLINE Vec512& operator%=(const Vec512 other) { |
| return *this = (*this % other); |
| } |
| HWY_INLINE Vec512& operator&=(const Vec512 other) { |
| return *this = (*this & other); |
| } |
| HWY_INLINE Vec512& operator|=(const Vec512 other) { |
| return *this = (*this | other); |
| } |
| HWY_INLINE Vec512& operator^=(const Vec512 other) { |
| return *this = (*this ^ other); |
| } |
| |
| Raw raw; |
| }; |
| |
| // Mask register: one bit per lane. |
| template <typename T> |
| struct Mask512 { |
| using Raw = typename detail::RawMask512<sizeof(T)>::type; |
| |
| using PrivateT = T; // only for DFromM |
| static constexpr size_t kPrivateN = 64 / sizeof(T); // only for DFromM |
| |
| Raw raw; |
| }; |
| |
| template <typename T> |
| using Full512 = Simd<T, 64 / sizeof(T), 0>; |
| |
| // ------------------------------ BitCast |
| |
| namespace detail { |
| |
| HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; } |
| #if HWY_HAVE_FLOAT16 |
| HWY_INLINE __m512i BitCastToInteger(__m512h v) { |
| return _mm512_castph_si512(v); |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); } |
| HWY_INLINE __m512i BitCastToInteger(__m512d v) { |
| return _mm512_castpd_si512(v); |
| } |
| |
| #if HWY_AVX3_HAVE_F32_TO_BF16C |
| HWY_INLINE __m512i BitCastToInteger(__m512bh v) { |
| // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to |
| // bit cast a __m512bh to a __m512i as there is currently no intrinsic |
| // available (as of GCC 13 and Clang 17) that can bit cast a __m512bh vector |
| // to a __m512i vector. |
| |
| #if HWY_COMPILER_GCC || HWY_COMPILER_CLANG |
| // On GCC or Clang, use reinterpret_cast to bit cast a __m512bh to a __m512i |
| return reinterpret_cast<__m512i>(v); |
| #else |
| // On MSVC, use BitCastScalar to bit cast a __m512bh to a __m512i as MSVC does |
| // not allow reinterpret_cast, static_cast, or a C-style cast to be used to |
| // bit cast from one AVX vector type to a different AVX vector type |
| return BitCastScalar<__m512i>(v); |
| #endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG |
| } |
| #endif // HWY_AVX3_HAVE_F32_TO_BF16C |
| |
| template <typename T> |
| HWY_INLINE Vec512<uint8_t> BitCastToByte(Vec512<T> v) { |
| return Vec512<uint8_t>{BitCastToInteger(v.raw)}; |
| } |
| |
| // Cannot rely on function overloading because return types differ. |
| template <typename T> |
| struct BitCastFromInteger512 { |
| HWY_INLINE __m512i operator()(__m512i v) { return v; } |
| }; |
| #if HWY_HAVE_FLOAT16 |
| template <> |
| struct BitCastFromInteger512<float16_t> { |
| HWY_INLINE __m512h operator()(__m512i v) { return _mm512_castsi512_ph(v); } |
| }; |
| #endif // HWY_HAVE_FLOAT16 |
| template <> |
| struct BitCastFromInteger512<float> { |
| HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); } |
| }; |
| template <> |
| struct BitCastFromInteger512<double> { |
| HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); } |
| }; |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */, Vec512<uint8_t> v) { |
| return VFromD<D>{BitCastFromInteger512<TFromD<D>>()(v.raw)}; |
| } |
| |
| } // namespace detail |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), typename FromT> |
| HWY_API VFromD<D> BitCast(D d, Vec512<FromT> v) { |
| return detail::BitCastFromByte(d, detail::BitCastToByte(v)); |
| } |
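| // Usage sketch (illustrative only): reinterpret f32 lanes as their u32 bit |
| // patterns. |
| //   const Full512<float> df; |
| //   const RebindToUnsigned<decltype(df)> du;      // Full512<uint32_t> |
| //   const auto bits = BitCast(du, Set(df, 1.0f));  // each lane is 0x3F800000 |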
| |
| // ------------------------------ Set |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) { |
| return VFromD<D>{_mm512_set1_epi8(static_cast<char>(t))}; // NOLINT |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI16_D(D)> |
| HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) { |
| return VFromD<D>{_mm512_set1_epi16(static_cast<short>(t))}; // NOLINT |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) { |
| return VFromD<D>{_mm512_set1_epi32(static_cast<int>(t))}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) { |
| return VFromD<D>{_mm512_set1_epi64(static_cast<long long>(t))}; // NOLINT |
| } |
| // bfloat16_t is handled by x86_128-inl.h. |
| #if HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_API Vec512<float16_t> Set(D /* tag */, float16_t t) { |
| return Vec512<float16_t>{_mm512_set1_ph(t)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API Vec512<float> Set(D /* tag */, float t) { |
| return Vec512<float>{_mm512_set1_ps(t)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API Vec512<double> Set(D /* tag */, double t) { |
| return Vec512<double>{_mm512_set1_pd(t)}; |
| } |
| |
| // ------------------------------ Zero (Set) |
| |
| // GCC pre-9.1 lacked setzero, so use Set instead. |
| #if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 |
| |
| // Cannot use VFromD here because it is defined in terms of Zero. |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_SPECIAL_FLOAT_D(D)> |
| HWY_API Vec512<TFromD<D>> Zero(D d) { |
| return Set(d, TFromD<D>{0}); |
| } |
| // BitCast is defined below, but the Raw type is the same, so use that. |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)> |
| HWY_API Vec512<bfloat16_t> Zero(D /* tag */) { |
| const RebindToUnsigned<D> du; |
| return Vec512<bfloat16_t>{Set(du, 0).raw}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_API Vec512<float16_t> Zero(D /* tag */) { |
| const RebindToUnsigned<D> du; |
| return Vec512<float16_t>{Set(du, 0).raw}; |
| } |
| |
| #else |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> |
| HWY_API Vec512<TFromD<D>> Zero(D /* tag */) { |
| return Vec512<TFromD<D>>{_mm512_setzero_si512()}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)> |
| HWY_API Vec512<bfloat16_t> Zero(D /* tag */) { |
| return Vec512<bfloat16_t>{_mm512_setzero_si512()}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_API Vec512<float16_t> Zero(D /* tag */) { |
| #if HWY_HAVE_FLOAT16 |
| return Vec512<float16_t>{_mm512_setzero_ph()}; |
| #else |
| return Vec512<float16_t>{_mm512_setzero_si512()}; |
| #endif |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API Vec512<float> Zero(D /* tag */) { |
| return Vec512<float>{_mm512_setzero_ps()}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API Vec512<double> Zero(D /* tag */) { |
| return Vec512<double>{_mm512_setzero_pd()}; |
| } |
| |
| #endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 |
| |
| // ------------------------------ Undefined |
| |
| HWY_DIAGNOSTICS(push) |
| HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") |
| |
| // Returns a vector with uninitialized elements. |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> |
| HWY_API Vec512<TFromD<D>> Undefined(D /* tag */) { |
| // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC |
| // generate an XOR instruction. |
| return Vec512<TFromD<D>>{_mm512_undefined_epi32()}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)> |
| HWY_API Vec512<bfloat16_t> Undefined(D /* tag */) { |
| return Vec512<bfloat16_t>{_mm512_undefined_epi32()}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_API Vec512<float16_t> Undefined(D /* tag */) { |
| #if HWY_HAVE_FLOAT16 |
| return Vec512<float16_t>{_mm512_undefined_ph()}; |
| #else |
| return Vec512<float16_t>{_mm512_undefined_epi32()}; |
| #endif |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API Vec512<float> Undefined(D /* tag */) { |
| return Vec512<float>{_mm512_undefined_ps()}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API Vec512<double> Undefined(D /* tag */) { |
| return Vec512<double>{_mm512_undefined_pd()}; |
| } |
| |
| HWY_DIAGNOSTICS(pop) |
| |
| // ------------------------------ ResizeBitCast |
| |
| // 64-byte vector to 16-byte vector |
| template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 64), |
| HWY_IF_V_SIZE_D(D, 16)> |
| HWY_API VFromD<D> ResizeBitCast(D d, FromV v) { |
| return BitCast(d, Vec128<uint8_t>{_mm512_castsi512_si128( |
| BitCast(Full512<uint8_t>(), v).raw)}); |
| } |
| |
| // <= 16-byte vector to 64-byte vector |
| template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 16), |
| HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> ResizeBitCast(D d, FromV v) { |
| return BitCast(d, Vec512<uint8_t>{_mm512_castsi128_si512( |
| ResizeBitCast(Full128<uint8_t>(), v).raw)}); |
| } |
| |
| // 32-byte vector to 64-byte vector |
| template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 32), |
| HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> ResizeBitCast(D d, FromV v) { |
| return BitCast(d, Vec512<uint8_t>{_mm512_castsi256_si512( |
| BitCast(Full256<uint8_t>(), v).raw)}); |
| } |
| |
| // ------------------------------ Dup128VecFromValues |
| |
| template <class D, HWY_IF_UI8_D(D), HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3, TFromD<D> t4, |
| TFromD<D> t5, TFromD<D> t6, TFromD<D> t7, |
| TFromD<D> t8, TFromD<D> t9, TFromD<D> t10, |
| TFromD<D> t11, TFromD<D> t12, |
| TFromD<D> t13, TFromD<D> t14, |
| TFromD<D> t15) { |
| #if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 |
| // Missing set_epi8/16. |
| return BroadcastBlock<0>(ResizeBitCast( |
| d, Dup128VecFromValues(Full128<TFromD<D>>(), t0, t1, t2, t3, t4, t5, t6, |
| t7, t8, t9, t10, t11, t12, t13, t14, t15))); |
| #else |
| (void)d; |
| // Need to use _mm512_set_epi8 as there is no _mm512_setr_epi8 intrinsic |
| // available. |
| return VFromD<D>{_mm512_set_epi8( |
| static_cast<char>(t15), static_cast<char>(t14), static_cast<char>(t13), |
| static_cast<char>(t12), static_cast<char>(t11), static_cast<char>(t10), |
| static_cast<char>(t9), static_cast<char>(t8), static_cast<char>(t7), |
| static_cast<char>(t6), static_cast<char>(t5), static_cast<char>(t4), |
| static_cast<char>(t3), static_cast<char>(t2), static_cast<char>(t1), |
| static_cast<char>(t0), static_cast<char>(t15), static_cast<char>(t14), |
| static_cast<char>(t13), static_cast<char>(t12), static_cast<char>(t11), |
| static_cast<char>(t10), static_cast<char>(t9), static_cast<char>(t8), |
| static_cast<char>(t7), static_cast<char>(t6), static_cast<char>(t5), |
| static_cast<char>(t4), static_cast<char>(t3), static_cast<char>(t2), |
| static_cast<char>(t1), static_cast<char>(t0), static_cast<char>(t15), |
| static_cast<char>(t14), static_cast<char>(t13), static_cast<char>(t12), |
| static_cast<char>(t11), static_cast<char>(t10), static_cast<char>(t9), |
| static_cast<char>(t8), static_cast<char>(t7), static_cast<char>(t6), |
| static_cast<char>(t5), static_cast<char>(t4), static_cast<char>(t3), |
| static_cast<char>(t2), static_cast<char>(t1), static_cast<char>(t0), |
| static_cast<char>(t15), static_cast<char>(t14), static_cast<char>(t13), |
| static_cast<char>(t12), static_cast<char>(t11), static_cast<char>(t10), |
| static_cast<char>(t9), static_cast<char>(t8), static_cast<char>(t7), |
| static_cast<char>(t6), static_cast<char>(t5), static_cast<char>(t4), |
| static_cast<char>(t3), static_cast<char>(t2), static_cast<char>(t1), |
| static_cast<char>(t0))}; |
| #endif |
| } |
| |
| template <class D, HWY_IF_UI16_D(D), HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3, TFromD<D> t4, |
| TFromD<D> t5, TFromD<D> t6, |
| TFromD<D> t7) { |
| #if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 |
| // Missing set_epi8/16. |
| return BroadcastBlock<0>( |
| ResizeBitCast(d, Dup128VecFromValues(Full128<TFromD<D>>(), t0, t1, t2, t3, |
| t4, t5, t6, t7))); |
| #else |
| (void)d; |
| // Need to use _mm512_set_epi16 as there is no _mm512_setr_epi16 intrinsic |
| // available. |
| return VFromD<D>{ |
| _mm512_set_epi16(static_cast<int16_t>(t7), static_cast<int16_t>(t6), |
| static_cast<int16_t>(t5), static_cast<int16_t>(t4), |
| static_cast<int16_t>(t3), static_cast<int16_t>(t2), |
| static_cast<int16_t>(t1), static_cast<int16_t>(t0), |
| static_cast<int16_t>(t7), static_cast<int16_t>(t6), |
| static_cast<int16_t>(t5), static_cast<int16_t>(t4), |
| static_cast<int16_t>(t3), static_cast<int16_t>(t2), |
| static_cast<int16_t>(t1), static_cast<int16_t>(t0), |
| static_cast<int16_t>(t7), static_cast<int16_t>(t6), |
| static_cast<int16_t>(t5), static_cast<int16_t>(t4), |
| static_cast<int16_t>(t3), static_cast<int16_t>(t2), |
| static_cast<int16_t>(t1), static_cast<int16_t>(t0), |
| static_cast<int16_t>(t7), static_cast<int16_t>(t6), |
| static_cast<int16_t>(t5), static_cast<int16_t>(t4), |
| static_cast<int16_t>(t3), static_cast<int16_t>(t2), |
| static_cast<int16_t>(t1), static_cast<int16_t>(t0))}; |
| #endif |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_F16_D(D), HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3, TFromD<D> t4, |
| TFromD<D> t5, TFromD<D> t6, |
| TFromD<D> t7) { |
| return VFromD<D>{_mm512_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2, |
| t3, t4, t5, t6, t7, t0, t1, t2, t3, t4, t5, |
| t6, t7, t0, t1, t2, t3, t4, t5, t6, t7)}; |
| } |
| #endif |
| |
| template <class D, HWY_IF_UI32_D(D), HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3) { |
| return VFromD<D>{ |
| _mm512_setr_epi32(static_cast<int32_t>(t0), static_cast<int32_t>(t1), |
| static_cast<int32_t>(t2), static_cast<int32_t>(t3), |
| static_cast<int32_t>(t0), static_cast<int32_t>(t1), |
| static_cast<int32_t>(t2), static_cast<int32_t>(t3), |
| static_cast<int32_t>(t0), static_cast<int32_t>(t1), |
| static_cast<int32_t>(t2), static_cast<int32_t>(t3), |
| static_cast<int32_t>(t0), static_cast<int32_t>(t1), |
| static_cast<int32_t>(t2), static_cast<int32_t>(t3))}; |
| } |
| |
| template <class D, HWY_IF_F32_D(D), HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3) { |
| return VFromD<D>{_mm512_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3, t0, t1, t2, |
| t3, t0, t1, t2, t3)}; |
| } |
| |
| template <class D, HWY_IF_UI64_D(D), HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) { |
| return VFromD<D>{ |
| _mm512_setr_epi64(static_cast<int64_t>(t0), static_cast<int64_t>(t1), |
| static_cast<int64_t>(t0), static_cast<int64_t>(t1), |
| static_cast<int64_t>(t0), static_cast<int64_t>(t1), |
| static_cast<int64_t>(t0), static_cast<int64_t>(t1))}; |
| } |
| |
| template <class D, HWY_IF_F64_D(D), HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) { |
| return VFromD<D>{_mm512_setr_pd(t0, t1, t0, t1, t0, t1, t0, t1)}; |
| } |
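| // Dup128VecFromValues repeats the given 128-bit pattern in every block. For |
| // example, with u32 lanes, Dup128VecFromValues(d, 0, 1, 2, 3) yields |
| // {0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3}. |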
| |
| // ------------------------------ Iota |
| |
| namespace detail { |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_INLINE VFromD<D> Iota0(D d) { |
| #if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 |
| // Missing set_epi8/16. |
| alignas(64) static constexpr TFromD<D> kIota[64] = { |
| 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, |
| 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, |
| 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, |
| 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}; |
| return Load(d, kIota); |
| #else |
| (void)d; |
| return VFromD<D>{_mm512_set_epi8( |
| static_cast<char>(63), static_cast<char>(62), static_cast<char>(61), |
| static_cast<char>(60), static_cast<char>(59), static_cast<char>(58), |
| static_cast<char>(57), static_cast<char>(56), static_cast<char>(55), |
| static_cast<char>(54), static_cast<char>(53), static_cast<char>(52), |
| static_cast<char>(51), static_cast<char>(50), static_cast<char>(49), |
| static_cast<char>(48), static_cast<char>(47), static_cast<char>(46), |
| static_cast<char>(45), static_cast<char>(44), static_cast<char>(43), |
| static_cast<char>(42), static_cast<char>(41), static_cast<char>(40), |
| static_cast<char>(39), static_cast<char>(38), static_cast<char>(37), |
| static_cast<char>(36), static_cast<char>(35), static_cast<char>(34), |
| static_cast<char>(33), static_cast<char>(32), static_cast<char>(31), |
| static_cast<char>(30), static_cast<char>(29), static_cast<char>(28), |
| static_cast<char>(27), static_cast<char>(26), static_cast<char>(25), |
| static_cast<char>(24), static_cast<char>(23), static_cast<char>(22), |
| static_cast<char>(21), static_cast<char>(20), static_cast<char>(19), |
| static_cast<char>(18), static_cast<char>(17), static_cast<char>(16), |
| static_cast<char>(15), static_cast<char>(14), static_cast<char>(13), |
| static_cast<char>(12), static_cast<char>(11), static_cast<char>(10), |
| static_cast<char>(9), static_cast<char>(8), static_cast<char>(7), |
| static_cast<char>(6), static_cast<char>(5), static_cast<char>(4), |
| static_cast<char>(3), static_cast<char>(2), static_cast<char>(1), |
| static_cast<char>(0))}; |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI16_D(D)> |
| HWY_INLINE VFromD<D> Iota0(D d) { |
| #if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 |
| // Missing set_epi8/16. |
| alignas(64) static constexpr TFromD<D> kIota[32] = { |
| 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, |
| 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; |
| return Load(d, kIota); |
| #else |
| (void)d; |
| return VFromD<D>{_mm512_set_epi16( |
| int16_t{31}, int16_t{30}, int16_t{29}, int16_t{28}, int16_t{27}, |
| int16_t{26}, int16_t{25}, int16_t{24}, int16_t{23}, int16_t{22}, |
| int16_t{21}, int16_t{20}, int16_t{19}, int16_t{18}, int16_t{17}, |
| int16_t{16}, int16_t{15}, int16_t{14}, int16_t{13}, int16_t{12}, |
| int16_t{11}, int16_t{10}, int16_t{9}, int16_t{8}, int16_t{7}, int16_t{6}, |
| int16_t{5}, int16_t{4}, int16_t{3}, int16_t{2}, int16_t{1}, int16_t{0})}; |
| #endif |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_INLINE VFromD<D> Iota0(D /*d*/) { |
| return VFromD<D>{_mm512_set_ph( |
| float16_t{31}, float16_t{30}, float16_t{29}, float16_t{28}, float16_t{27}, |
| float16_t{26}, float16_t{25}, float16_t{24}, float16_t{23}, float16_t{22}, |
| float16_t{21}, float16_t{20}, float16_t{19}, float16_t{18}, float16_t{17}, |
| float16_t{16}, float16_t{15}, float16_t{14}, float16_t{13}, float16_t{12}, |
| float16_t{11}, float16_t{10}, float16_t{9}, float16_t{8}, float16_t{7}, |
| float16_t{6}, float16_t{5}, float16_t{4}, float16_t{3}, float16_t{2}, |
| float16_t{1}, float16_t{0})}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_INLINE VFromD<D> Iota0(D /*d*/) { |
| return VFromD<D>{_mm512_set_epi32( |
| int32_t{15}, int32_t{14}, int32_t{13}, int32_t{12}, int32_t{11}, |
| int32_t{10}, int32_t{9}, int32_t{8}, int32_t{7}, int32_t{6}, int32_t{5}, |
| int32_t{4}, int32_t{3}, int32_t{2}, int32_t{1}, int32_t{0})}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_INLINE VFromD<D> Iota0(D /*d*/) { |
| return VFromD<D>{_mm512_set_epi64(int64_t{7}, int64_t{6}, int64_t{5}, |
| int64_t{4}, int64_t{3}, int64_t{2}, |
| int64_t{1}, int64_t{0})}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_INLINE VFromD<D> Iota0(D /*d*/) { |
| return VFromD<D>{_mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, |
| 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, |
| 0.0f)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_INLINE VFromD<D> Iota0(D /*d*/) { |
| return VFromD<D>{_mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0)}; |
| } |
| |
| } // namespace detail |
| |
| template <class D, typename T2, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> Iota(D d, const T2 first) { |
| return detail::Iota0(d) + Set(d, ConvertScalarTo<TFromD<D>>(first)); |
| } |
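| // For example, Iota(Full512<int32_t>(), 10) is {10, 11, ..., 25}: the first |
| // lane holds `first` and each subsequent lane increments by one. |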
| |
| // ================================================== LOGICAL |
| |
| // ------------------------------ Not |
| |
| template <typename T> |
| HWY_API Vec512<T> Not(const Vec512<T> v) { |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| using VU = VFromD<decltype(du)>; |
| const __m512i vu = BitCast(du, v).raw; |
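| // The ternary-logic truth table 0x55 is the complement of 0xAA (the table |
| // for the third operand), i.e. ~C; with all three operands equal to vu, |
| // this computes the bitwise NOT of v in a single instruction. |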
| return BitCast(d, VU{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)}); |
| } |
| |
| // ------------------------------ And |
| |
| template <typename T> |
| HWY_API Vec512<T> And(const Vec512<T> a, const Vec512<T> b) { |
| const DFromV<decltype(a)> d; // for float16_t |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, VFromD<decltype(du)>{_mm512_and_si512(BitCast(du, a).raw, |
| BitCast(du, b).raw)}); |
| } |
| |
| HWY_API Vec512<float> And(const Vec512<float> a, const Vec512<float> b) { |
| return Vec512<float>{_mm512_and_ps(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<double> And(const Vec512<double> a, const Vec512<double> b) { |
| return Vec512<double>{_mm512_and_pd(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ AndNot |
| |
| // Returns ~not_mask & mask. |
| template <typename T> |
| HWY_API Vec512<T> AndNot(const Vec512<T> not_mask, const Vec512<T> mask) { |
| const DFromV<decltype(mask)> d; // for float16_t |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, VFromD<decltype(du)>{_mm512_andnot_si512( |
| BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); |
| } |
| HWY_API Vec512<float> AndNot(const Vec512<float> not_mask, |
| const Vec512<float> mask) { |
| return Vec512<float>{_mm512_andnot_ps(not_mask.raw, mask.raw)}; |
| } |
| HWY_API Vec512<double> AndNot(const Vec512<double> not_mask, |
| const Vec512<double> mask) { |
| return Vec512<double>{_mm512_andnot_pd(not_mask.raw, mask.raw)}; |
| } |
| |
| // ------------------------------ Or |
| |
| template <typename T> |
| HWY_API Vec512<T> Or(const Vec512<T> a, const Vec512<T> b) { |
| const DFromV<decltype(a)> d; // for float16_t |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, VFromD<decltype(du)>{_mm512_or_si512(BitCast(du, a).raw, |
| BitCast(du, b).raw)}); |
| } |
| |
| HWY_API Vec512<float> Or(const Vec512<float> a, const Vec512<float> b) { |
| return Vec512<float>{_mm512_or_ps(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<double> Or(const Vec512<double> a, const Vec512<double> b) { |
| return Vec512<double>{_mm512_or_pd(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Xor |
| |
| template <typename T> |
| HWY_API Vec512<T> Xor(const Vec512<T> a, const Vec512<T> b) { |
| const DFromV<decltype(a)> d; // for float16_t |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, VFromD<decltype(du)>{_mm512_xor_si512(BitCast(du, a).raw, |
| BitCast(du, b).raw)}); |
| } |
| |
| HWY_API Vec512<float> Xor(const Vec512<float> a, const Vec512<float> b) { |
| return Vec512<float>{_mm512_xor_ps(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<double> Xor(const Vec512<double> a, const Vec512<double> b) { |
| return Vec512<double>{_mm512_xor_pd(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Xor3 |
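| // The _mm512_ternarylogic_epi64 immediates below are 8-entry truth tables |
| // indexed by (A<<2)|(B<<1)|C: 0x96 = A^B^C (Xor3), 0xFE = A|B|C (Or3), |
| // 0xF8 = A|(B&C) (OrAnd) and 0xCA = A ? B : C (IfVecThenElse). |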
| template <typename T> |
| HWY_API Vec512<T> Xor3(Vec512<T> x1, Vec512<T> x2, Vec512<T> x3) { |
| #if !HWY_IS_MSAN |
| const DFromV<decltype(x1)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| using VU = VFromD<decltype(du)>; |
| const __m512i ret = _mm512_ternarylogic_epi64( |
| BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); |
| return BitCast(d, VU{ret}); |
| #else |
| return Xor(x1, Xor(x2, x3)); |
| #endif |
| } |
| |
| // ------------------------------ Or3 |
| template <typename T> |
| HWY_API Vec512<T> Or3(Vec512<T> o1, Vec512<T> o2, Vec512<T> o3) { |
| #if !HWY_IS_MSAN |
| const DFromV<decltype(o1)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| using VU = VFromD<decltype(du)>; |
| const __m512i ret = _mm512_ternarylogic_epi64( |
| BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); |
| return BitCast(d, VU{ret}); |
| #else |
| return Or(o1, Or(o2, o3)); |
| #endif |
| } |
| |
| // ------------------------------ OrAnd |
| template <typename T> |
| HWY_API Vec512<T> OrAnd(Vec512<T> o, Vec512<T> a1, Vec512<T> a2) { |
| #if !HWY_IS_MSAN |
| const DFromV<decltype(o)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| using VU = VFromD<decltype(du)>; |
| const __m512i ret = _mm512_ternarylogic_epi64( |
| BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); |
| return BitCast(d, VU{ret}); |
| #else |
| return Or(o, And(a1, a2)); |
| #endif |
| } |
| |
| // ------------------------------ IfVecThenElse |
| template <typename T> |
| HWY_API Vec512<T> IfVecThenElse(Vec512<T> mask, Vec512<T> yes, Vec512<T> no) { |
| #if !HWY_IS_MSAN |
| const DFromV<decltype(yes)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| using VU = VFromD<decltype(du)>; |
| return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw, |
| BitCast(du, yes).raw, |
| BitCast(du, no).raw, 0xCA)}); |
| #else |
| return IfThenElse(MaskFromVec(mask), yes, no); |
| #endif |
| } |
| |
| // ------------------------------ Operator overloads (internal-only if float) |
| |
| template <typename T> |
| HWY_API Vec512<T> operator&(const Vec512<T> a, const Vec512<T> b) { |
| return And(a, b); |
| } |
| |
| template <typename T> |
| HWY_API Vec512<T> operator|(const Vec512<T> a, const Vec512<T> b) { |
| return Or(a, b); |
| } |
| |
| template <typename T> |
| HWY_API Vec512<T> operator^(const Vec512<T> a, const Vec512<T> b) { |
| return Xor(a, b); |
| } |
| |
| // ------------------------------ PopulationCount |
| |
| // 8/16 require BITALG, 32/64 require VPOPCNTDQ. |
| #if HWY_TARGET <= HWY_AVX3_DL |
| |
| #ifdef HWY_NATIVE_POPCNT |
| #undef HWY_NATIVE_POPCNT |
| #else |
| #define HWY_NATIVE_POPCNT |
| #endif |
| |
| namespace detail { |
| |
| template <typename T> |
| HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<1> /* tag */, Vec512<T> v) { |
| return Vec512<T>{_mm512_popcnt_epi8(v.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<2> /* tag */, Vec512<T> v) { |
| return Vec512<T>{_mm512_popcnt_epi16(v.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<4> /* tag */, Vec512<T> v) { |
| return Vec512<T>{_mm512_popcnt_epi32(v.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<8> /* tag */, Vec512<T> v) { |
| return Vec512<T>{_mm512_popcnt_epi64(v.raw)}; |
| } |
| |
| } // namespace detail |
| |
| template <typename T> |
| HWY_API Vec512<T> PopulationCount(Vec512<T> v) { |
| return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v); |
| } |
| |
| #endif // HWY_TARGET <= HWY_AVX3_DL |
| |
| // ================================================== MASK |
| |
| // ------------------------------ FirstN |
| |
| // Possibilities for constructing a bitmask of N ones: |
| // - kshift* only consider the lowest byte of the shift count, so they would |
| // not correctly handle large n. |
| // - Scalar shifts >= 64 are UB. |
| // - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However, |
| // we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds. |
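| // For example, _bzhi_u64(~uint64_t{0}, 3) == 0b111, and a count of 64 leaves |
| // all 64 bits set, which matches the desired FirstN semantics when n equals |
| // the lane count. |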
| |
| #if HWY_ARCH_X86_32 |
| namespace detail { |
| |
| // 32 bit mask is sufficient for lane size >= 2. |
| template <typename T, HWY_IF_NOT_T_SIZE(T, 1)> |
| HWY_INLINE Mask512<T> FirstN(size_t n) { |
| Mask512<T> m; |
| const uint32_t all = ~uint32_t{0}; |
| // BZHI only looks at the lower 8 bits of n, but it has been clamped to |
| // MaxLanes, which is at most 32. |
| m.raw = static_cast<decltype(m.raw)>(_bzhi_u32(all, n)); |
| return m; |
| } |
| |
| #if HWY_COMPILER_MSVC >= 1920 || HWY_COMPILER_GCC_ACTUAL >= 900 || \ |
| HWY_COMPILER_CLANG || HWY_COMPILER_ICC |
| template <typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_INLINE Mask512<T> FirstN(size_t n) { |
| uint32_t lo_mask; |
| uint32_t hi_mask; |
| uint32_t hi_mask_len; |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(n >= 32) && n >= 32) { |
| if (__builtin_constant_p(n >= 64) && n >= 64) { |
| hi_mask_len = 32u; |
| } else { |
| hi_mask_len = static_cast<uint32_t>(n) - 32u; |
| } |
| lo_mask = hi_mask = 0xFFFFFFFFu; |
| } else // NOLINT(readability/braces) |
| #endif |
| { |
| const uint32_t lo_mask_len = static_cast<uint32_t>(n); |
| lo_mask = _bzhi_u32(0xFFFFFFFFu, lo_mask_len); |
| |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(lo_mask_len <= 32) && lo_mask_len <= 32) { |
| return Mask512<T>{static_cast<__mmask64>(lo_mask)}; |
| } |
| #endif |
| |
| _addcarry_u32(_subborrow_u32(0, lo_mask_len, 32u, &hi_mask_len), |
| 0xFFFFFFFFu, 0u, &hi_mask); |
| } |
| hi_mask = _bzhi_u32(hi_mask, hi_mask_len); |
| #if HWY_COMPILER_GCC && !HWY_COMPILER_ICC |
| if (__builtin_constant_p((static_cast<uint64_t>(hi_mask) << 32) | lo_mask)) |
| #endif |
| return Mask512<T>{static_cast<__mmask64>( |
| (static_cast<uint64_t>(hi_mask) << 32) | lo_mask)}; |
| #if HWY_COMPILER_GCC && !HWY_COMPILER_ICC |
| else |
| return Mask512<T>{_mm512_kunpackd(static_cast<__mmask64>(hi_mask), |
| static_cast<__mmask64>(lo_mask))}; |
| #endif |
| } |
| #else // HWY_COMPILER.. |
| template <typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_INLINE Mask512<T> FirstN(size_t n) { |
| const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t{0}; |
| return Mask512<T>{static_cast<__mmask64>(bits)}; |
| } |
| #endif // HWY_COMPILER.. |
| } // namespace detail |
| #endif // HWY_ARCH_X86_32 |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API MFromD<D> FirstN(D d, size_t n) { |
| // This ensures `n` <= 255 as required by bzhi, which only looks |
| // at the lower 8 bits. |
| n = HWY_MIN(n, MaxLanes(d)); |
| |
| #if HWY_ARCH_X86_64 |
| MFromD<D> m; |
| const uint64_t all = ~uint64_t{0}; |
| m.raw = static_cast<decltype(m.raw)>(_bzhi_u64(all, n)); |
| return m; |
| #else |
| return detail::FirstN<TFromD<D>>(n); |
| #endif // HWY_ARCH_X86_64 |
| } |
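| // Usage sketch (illustrative only): mask off the remainder of a loop. |
| //   const Full512<int32_t> d; |
| //   const auto m = FirstN(d, 5);                   // lanes 0..4 active |
| //   const auto v = IfThenElseZero(m, Iota(d, 1));  // 1,2,3,4,5,0,0,... |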
| |
| // ------------------------------ IfThenElse |
| |
| // Returns mask ? b : a. |
| |
| namespace detail { |
| |
| // Templates for signed/unsigned integer of a particular size. |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<1> /* tag */, |
| const Mask512<T> mask, const Vec512<T> yes, |
| const Vec512<T> no) { |
| return Vec512<T>{_mm512_mask_blend_epi8(mask.raw, no.raw, yes.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<2> /* tag */, |
| const Mask512<T> mask, const Vec512<T> yes, |
| const Vec512<T> no) { |
| return Vec512<T>{_mm512_mask_blend_epi16(mask.raw, no.raw, yes.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<4> /* tag */, |
| const Mask512<T> mask, const Vec512<T> yes, |
| const Vec512<T> no) { |
| return Vec512<T>{_mm512_mask_blend_epi32(mask.raw, no.raw, yes.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<8> /* tag */, |
| const Mask512<T> mask, const Vec512<T> yes, |
| const Vec512<T> no) { |
| return Vec512<T>{_mm512_mask_blend_epi64(mask.raw, no.raw, yes.raw)}; |
| } |
| |
| } // namespace detail |
| |
| template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)> |
| HWY_API Vec512<T> IfThenElse(const Mask512<T> mask, const Vec512<T> yes, |
| const Vec512<T> no) { |
| return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no); |
| } |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> IfThenElse(Mask512<float16_t> mask, |
| Vec512<float16_t> yes, |
| Vec512<float16_t> no) { |
| return Vec512<float16_t>{_mm512_mask_blend_ph(mask.raw, no.raw, yes.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> IfThenElse(Mask512<float> mask, Vec512<float> yes, |
| Vec512<float> no) { |
| return Vec512<float>{_mm512_mask_blend_ps(mask.raw, no.raw, yes.raw)}; |
| } |
| HWY_API Vec512<double> IfThenElse(Mask512<double> mask, Vec512<double> yes, |
| Vec512<double> no) { |
| return Vec512<double>{_mm512_mask_blend_pd(mask.raw, no.raw, yes.raw)}; |
| } |
| |
| namespace detail { |
| |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<1> /* tag */, |
| const Mask512<T> mask, |
| const Vec512<T> yes) { |
| return Vec512<T>{_mm512_maskz_mov_epi8(mask.raw, yes.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<2> /* tag */, |
| const Mask512<T> mask, |
| const Vec512<T> yes) { |
| return Vec512<T>{_mm512_maskz_mov_epi16(mask.raw, yes.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<4> /* tag */, |
| const Mask512<T> mask, |
| const Vec512<T> yes) { |
| return Vec512<T>{_mm512_maskz_mov_epi32(mask.raw, yes.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<8> /* tag */, |
| const Mask512<T> mask, |
| const Vec512<T> yes) { |
| return Vec512<T>{_mm512_maskz_mov_epi64(mask.raw, yes.raw)}; |
| } |
| |
| } // namespace detail |
| |
| template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)> |
| HWY_API Vec512<T> IfThenElseZero(const Mask512<T> mask, const Vec512<T> yes) { |
| return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes); |
| } |
| HWY_API Vec512<float> IfThenElseZero(Mask512<float> mask, Vec512<float> yes) { |
| return Vec512<float>{_mm512_maskz_mov_ps(mask.raw, yes.raw)}; |
| } |
| HWY_API Vec512<double> IfThenElseZero(Mask512<double> mask, |
| Vec512<double> yes) { |
| return Vec512<double>{_mm512_maskz_mov_pd(mask.raw, yes.raw)}; |
| } |
| |
| namespace detail { |
| |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<1> /* tag */, |
| const Mask512<T> mask, const Vec512<T> no) { |
| // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. |
| return Vec512<T>{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<2> /* tag */, |
| const Mask512<T> mask, const Vec512<T> no) { |
| return Vec512<T>{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<4> /* tag */, |
| const Mask512<T> mask, const Vec512<T> no) { |
| return Vec512<T>{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<8> /* tag */, |
| const Mask512<T> mask, const Vec512<T> no) { |
| return Vec512<T>{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; |
| } |
| |
| } // namespace detail |
| |
| template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)> |
| HWY_API Vec512<T> IfThenZeroElse(const Mask512<T> mask, const Vec512<T> no) { |
| return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no); |
| } |
| HWY_API Vec512<float> IfThenZeroElse(Mask512<float> mask, Vec512<float> no) { |
| return Vec512<float>{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; |
| } |
| HWY_API Vec512<double> IfThenZeroElse(Mask512<double> mask, Vec512<double> no) { |
| return Vec512<double>{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; |
| } |
| |
| template <typename T> |
| HWY_API Vec512<T> IfNegativeThenElse(Vec512<T> v, Vec512<T> yes, Vec512<T> no) { |
| static_assert(IsSigned<T>(), "Only works for signed/float"); |
| // AVX3 MaskFromVec only looks at the MSB |
| return IfThenElse(MaskFromVec(v), yes, no); |
| } |
| |
| template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T), |
| HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4))> |
| HWY_API Vec512<T> IfNegativeThenNegOrUndefIfZero(Vec512<T> mask, Vec512<T> v) { |
| // AVX3 MaskFromVec only looks at the MSB |
| const DFromV<decltype(v)> d; |
| return MaskedSubOr(v, MaskFromVec(mask), Zero(d), v); |
| } |
| |
| // ================================================== ARITHMETIC |
| |
| // ------------------------------ Addition |
| |
| // Unsigned |
| HWY_API Vec512<uint8_t> operator+(Vec512<uint8_t> a, Vec512<uint8_t> b) { |
| return Vec512<uint8_t>{_mm512_add_epi8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint16_t> operator+(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Vec512<uint16_t>{_mm512_add_epi16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint32_t> operator+(Vec512<uint32_t> a, Vec512<uint32_t> b) { |
| return Vec512<uint32_t>{_mm512_add_epi32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint64_t> operator+(Vec512<uint64_t> a, Vec512<uint64_t> b) { |
| return Vec512<uint64_t>{_mm512_add_epi64(a.raw, b.raw)}; |
| } |
| |
| // Signed |
| HWY_API Vec512<int8_t> operator+(Vec512<int8_t> a, Vec512<int8_t> b) { |
| return Vec512<int8_t>{_mm512_add_epi8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int16_t> operator+(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Vec512<int16_t>{_mm512_add_epi16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int32_t> operator+(Vec512<int32_t> a, Vec512<int32_t> b) { |
| return Vec512<int32_t>{_mm512_add_epi32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int64_t> operator+(Vec512<int64_t> a, Vec512<int64_t> b) { |
| return Vec512<int64_t>{_mm512_add_epi64(a.raw, b.raw)}; |
| } |
| |
| // Float |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> operator+(Vec512<float16_t> a, Vec512<float16_t> b) { |
| return Vec512<float16_t>{_mm512_add_ph(a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> operator+(Vec512<float> a, Vec512<float> b) { |
| return Vec512<float>{_mm512_add_ps(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<double> operator+(Vec512<double> a, Vec512<double> b) { |
| return Vec512<double>{_mm512_add_pd(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Subtraction |
| |
| // Unsigned |
| HWY_API Vec512<uint8_t> operator-(Vec512<uint8_t> a, Vec512<uint8_t> b) { |
| return Vec512<uint8_t>{_mm512_sub_epi8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint16_t> operator-(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Vec512<uint16_t>{_mm512_sub_epi16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint32_t> operator-(Vec512<uint32_t> a, Vec512<uint32_t> b) { |
| return Vec512<uint32_t>{_mm512_sub_epi32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint64_t> operator-(Vec512<uint64_t> a, Vec512<uint64_t> b) { |
| return Vec512<uint64_t>{_mm512_sub_epi64(a.raw, b.raw)}; |
| } |
| |
| // Signed |
| HWY_API Vec512<int8_t> operator-(Vec512<int8_t> a, Vec512<int8_t> b) { |
| return Vec512<int8_t>{_mm512_sub_epi8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int16_t> operator-(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Vec512<int16_t>{_mm512_sub_epi16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int32_t> operator-(Vec512<int32_t> a, Vec512<int32_t> b) { |
| return Vec512<int32_t>{_mm512_sub_epi32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int64_t> operator-(Vec512<int64_t> a, Vec512<int64_t> b) { |
| return Vec512<int64_t>{_mm512_sub_epi64(a.raw, b.raw)}; |
| } |
| |
| // Float |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> operator-(Vec512<float16_t> a, Vec512<float16_t> b) { |
| return Vec512<float16_t>{_mm512_sub_ph(a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> operator-(Vec512<float> a, Vec512<float> b) { |
| return Vec512<float>{_mm512_sub_ps(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<double> operator-(Vec512<double> a, Vec512<double> b) { |
| return Vec512<double>{_mm512_sub_pd(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ SumsOf8 |
| HWY_API Vec512<uint64_t> SumsOf8(const Vec512<uint8_t> v) { |
| const Full512<uint8_t> d; |
| return Vec512<uint64_t>{_mm512_sad_epu8(v.raw, Zero(d).raw)}; |
| } |
| |
| HWY_API Vec512<uint64_t> SumsOf8AbsDiff(Vec512<uint8_t> a, Vec512<uint8_t> b) { |
| return Vec512<uint64_t>{_mm512_sad_epu8(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ SumsOf4 |
| namespace detail { |
| |
| HWY_INLINE Vec512<uint32_t> SumsOf4(hwy::UnsignedTag /*type_tag*/, |
| hwy::SizeTag<1> /*lane_size_tag*/, |
| Vec512<uint8_t> v) { |
| const DFromV<decltype(v)> d; |
| |
| // _mm512_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be |
| // zeroed out and the sums of the 4 consecutive lanes are already in the |
| // even uint16_t lanes of the _mm512_maskz_dbsad_epu8 result. |
| return Vec512<uint32_t>{_mm512_maskz_dbsad_epu8( |
| static_cast<__mmask32>(0x55555555), v.raw, Zero(d).raw, 0)}; |
| } |
| |
| // I8->I32 SumsOf4 |
| // Generic for all vector lengths |
| template <class V> |
| HWY_INLINE VFromD<RepartitionToWideX2<DFromV<V>>> SumsOf4( |
| hwy::SignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) { |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| const RepartitionToWideX2<decltype(d)> di32; |
| |
| // Adjust the values of v to be in the 0..255 range by adding 128 to each |
| // lane of v (which is the same as a bitwise XOR of each i8 lane with 128) and |
| // then bitcasting the Xor result to a u8 vector. |
| const auto v_adj = BitCast(du, Xor(v, SignBit(d))); |
| |
| // Need to add -512 to each i32 lane of the result of the |
| // SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj) operation to account |
| // for the adjustment made above. |
| return BitCast(di32, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj)) + |
| Set(di32, int32_t{-512}); |
| } |
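| // Worked example for the signed path: i8 lanes {-1, -2, -3, -4} become |
| // u8 {127, 126, 125, 124} after the sign-bit flip; their sum is 502, and |
| // 502 - 512 = -10, which is indeed (-1) + (-2) + (-3) + (-4). |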
| |
| } // namespace detail |
| |
| // ------------------------------ SumsOfShuffledQuadAbsDiff |
| |
| #if HWY_TARGET <= HWY_AVX3 |
| template <int kIdx3, int kIdx2, int kIdx1, int kIdx0> |
| static Vec512<uint16_t> SumsOfShuffledQuadAbsDiff(Vec512<uint8_t> a, |
| Vec512<uint8_t> b) { |
| static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); |
| static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); |
| static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); |
| static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); |
| return Vec512<uint16_t>{ |
| _mm512_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; |
| } |
| #endif |
| |
| // ------------------------------ SaturatedAdd |
| |
| // Returns a + b clamped to the destination range. |
| |
| // Unsigned |
| HWY_API Vec512<uint8_t> SaturatedAdd(Vec512<uint8_t> a, Vec512<uint8_t> b) { |
| return Vec512<uint8_t>{_mm512_adds_epu8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint16_t> SaturatedAdd(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Vec512<uint16_t>{_mm512_adds_epu16(a.raw, b.raw)}; |
| } |
| |
| // Signed |
| HWY_API Vec512<int8_t> SaturatedAdd(Vec512<int8_t> a, Vec512<int8_t> b) { |
| return Vec512<int8_t>{_mm512_adds_epi8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int16_t> SaturatedAdd(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Vec512<int16_t>{_mm512_adds_epi16(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ SaturatedSub |
| |
| // Returns a - b clamped to the destination range. |
| |
| // Unsigned |
| HWY_API Vec512<uint8_t> SaturatedSub(Vec512<uint8_t> a, Vec512<uint8_t> b) { |
| return Vec512<uint8_t>{_mm512_subs_epu8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint16_t> SaturatedSub(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Vec512<uint16_t>{_mm512_subs_epu16(a.raw, b.raw)}; |
| } |
| |
| // Signed |
| HWY_API Vec512<int8_t> SaturatedSub(Vec512<int8_t> a, Vec512<int8_t> b) { |
| return Vec512<int8_t>{_mm512_subs_epi8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int16_t> SaturatedSub(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Vec512<int16_t>{_mm512_subs_epi16(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Average |
| |
| // Returns (a + b + 1) / 2 |
| |
| // Unsigned |
| HWY_API Vec512<uint8_t> AverageRound(Vec512<uint8_t> a, Vec512<uint8_t> b) { |
| return Vec512<uint8_t>{_mm512_avg_epu8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint16_t> AverageRound(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Vec512<uint16_t>{_mm512_avg_epu16(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Abs (Sub) |
| |
| // Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. |
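| // For example, Abs(int8_t{-128}) remains -128 because +128 is not |
| // representable; all other inputs return the usual absolute value. |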
| HWY_API Vec512<int8_t> Abs(const Vec512<int8_t> v) { |
| #if HWY_COMPILER_MSVC |
| // Workaround for incorrect codegen? (untested due to internal compiler error) |
| const DFromV<decltype(v)> d; |
| const auto zero = Zero(d); |
| return Vec512<int8_t>{_mm512_max_epi8(v.raw, (zero - v).raw)}; |
| #else |
| return Vec512<int8_t>{_mm512_abs_epi8(v.raw)}; |
| #endif |
| } |
| HWY_API Vec512<int16_t> Abs(const Vec512<int16_t> v) { |
| return Vec512<int16_t>{_mm512_abs_epi16(v.raw)}; |
| } |
| HWY_API Vec512<int32_t> Abs(const Vec512<int32_t> v) { |
| return Vec512<int32_t>{_mm512_abs_epi32(v.raw)}; |
| } |
| HWY_API Vec512<int64_t> Abs(const Vec512<int64_t> v) { |
| return Vec512<int64_t>{_mm512_abs_epi64(v.raw)}; |
| } |
| |
| // ------------------------------ ShiftLeft |
| |
| #if HWY_TARGET <= HWY_AVX3_DL |
| namespace detail { |
| template <typename T> |
| HWY_API Vec512<T> GaloisAffine(Vec512<T> v, Vec512<uint64_t> matrix) { |
| return Vec512<T>{_mm512_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)}; |
| } |
| } // namespace detail |
| #endif // HWY_TARGET <= HWY_AVX3_DL |
| |
| template <int kBits> |
| HWY_API Vec512<uint16_t> ShiftLeft(const Vec512<uint16_t> v) { |
| return Vec512<uint16_t>{_mm512_slli_epi16(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<uint32_t> ShiftLeft(const Vec512<uint32_t> v) { |
| return Vec512<uint32_t>{_mm512_slli_epi32(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<uint64_t> ShiftLeft(const Vec512<uint64_t> v) { |
| return Vec512<uint64_t>{_mm512_slli_epi64(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<int16_t> ShiftLeft(const Vec512<int16_t> v) { |
| return Vec512<int16_t>{_mm512_slli_epi16(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<int32_t> ShiftLeft(const Vec512<int32_t> v) { |
| return Vec512<int32_t>{_mm512_slli_epi32(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<int64_t> ShiftLeft(const Vec512<int64_t> v) { |
| return Vec512<int64_t>{_mm512_slli_epi64(v.raw, kBits)}; |
| } |
| |
| #if HWY_TARGET > HWY_AVX3_DL |
| |
| template <int kBits, typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_API Vec512<T> ShiftLeft(const Vec512<T> v) { |
| const DFromV<decltype(v)> d8; |
| const RepartitionToWide<decltype(d8)> d16; |
| const auto shifted = BitCast(d8, ShiftLeft<kBits>(BitCast(d16, v))); |
| return kBits == 1 |
| ? (v + v) |
| : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF))); |
| } |
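| // For example, with kBits == 3 the 16-bit shift moves bits from each low byte |
| // into the adjacent high byte; masking with 0xF8 (= (0xFF << 3) & 0xFF) |
| // clears those crossed-over bits so every byte is shifted independently. |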
| |
| #endif // HWY_TARGET > HWY_AVX3_DL |
| |
| // ------------------------------ ShiftRight |
| |
| template <int kBits> |
| HWY_API Vec512<uint16_t> ShiftRight(const Vec512<uint16_t> v) { |
| return Vec512<uint16_t>{_mm512_srli_epi16(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<uint32_t> ShiftRight(const Vec512<uint32_t> v) { |
| return Vec512<uint32_t>{_mm512_srli_epi32(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<uint64_t> ShiftRight(const Vec512<uint64_t> v) { |
| return Vec512<uint64_t>{_mm512_srli_epi64(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<int16_t> ShiftRight(const Vec512<int16_t> v) { |
| return Vec512<int16_t>{_mm512_srai_epi16(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<int32_t> ShiftRight(const Vec512<int32_t> v) { |
| return Vec512<int32_t>{_mm512_srai_epi32(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<int64_t> ShiftRight(const Vec512<int64_t> v) { |
| return Vec512<int64_t>{_mm512_srai_epi64(v.raw, kBits)}; |
| } |
| |
| #if HWY_TARGET > HWY_AVX3_DL |
| |
| template <int kBits> |
| HWY_API Vec512<uint8_t> ShiftRight(const Vec512<uint8_t> v) { |
| const DFromV<decltype(v)> d8; |
| // Use raw instead of BitCast to support N=1. |
| const Vec512<uint8_t> shifted{ShiftRight<kBits>(Vec512<uint16_t>{v.raw}).raw}; |
| return shifted & Set(d8, 0xFF >> kBits); |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<int8_t> ShiftRight(const Vec512<int8_t> v) { |
| const DFromV<decltype(v)> di; |
| const RebindToUnsigned<decltype(di)> du; |
| const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v))); |
| const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); |
| return (shifted ^ shifted_sign) - shifted_sign; |
| } |
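| // Worked example for kBits == 2: v = int8_t{-64} (0xC0). The unsigned shift |
| // gives 0x30; shifted_sign is 0x20, so (0x30 ^ 0x20) - 0x20 = 0x10 - 0x20 = |
| // -16, the correct arithmetic shift. XOR-then-subtract of the shifted sign |
| // bit re-extends the sign that the logical shift discarded. |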
| |
| #endif // HWY_TARGET > HWY_AVX3_DL |
| |
| // ------------------------------ RotateRight |
| |
| #if HWY_TARGET > HWY_AVX3_DL |
| template <int kBits> |
| HWY_API Vec512<uint8_t> RotateRight(const Vec512<uint8_t> v) { |
| static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); |
| if (kBits == 0) return v; |
| // AVX3 does not support 8-bit. |
| return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(7, 8 - kBits)>(v)); |
| } |
| #endif // HWY_TARGET > HWY_AVX3_DL |
| |
| template <int kBits> |
| HWY_API Vec512<uint16_t> RotateRight(const Vec512<uint16_t> v) { |
| static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); |
| if (kBits == 0) return v; |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return Vec512<uint16_t>{_mm512_shrdi_epi16(v.raw, v.raw, kBits)}; |
| #else |
| // AVX3 does not support 16-bit. |
| return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(15, 16 - kBits)>(v)); |
| #endif |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<uint32_t> RotateRight(const Vec512<uint32_t> v) { |
| static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); |
| if (kBits == 0) return v; |
| return Vec512<uint32_t>{_mm512_ror_epi32(v.raw, kBits)}; |
| } |
| |
| template <int kBits> |
| HWY_API Vec512<uint64_t> RotateRight(const Vec512<uint64_t> v) { |
| static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); |
| if (kBits == 0) return v; |
| return Vec512<uint64_t>{_mm512_ror_epi64(v.raw, kBits)}; |
| } |
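| // For example, RotateRight<8> of the u32 lane 0x12345678 yields 0x78123456: |
| // the low 8 bits wrap around to the top of the lane. |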
| |
| // ------------------------------ Rol/Ror |
| #if HWY_TARGET <= HWY_AVX3_DL |
| template <class T, HWY_IF_UI16(T)> |
| HWY_API Vec512<T> Ror(Vec512<T> a, Vec512<T> b) { |
| return Vec512<T>{_mm512_shrdv_epi16(a.raw, a.raw, b.raw)}; |
| } |
| #endif // HWY_TARGET <= HWY_AVX3_DL |
| |
| template <class T, HWY_IF_UI32(T)> |
| HWY_API Vec512<T> Rol(Vec512<T> a, Vec512<T> b) { |
| return Vec512<T>{_mm512_rolv_epi32(a.raw, b.raw)}; |
| } |
| |
| template <class T, HWY_IF_UI32(T)> |
| HWY_API Vec512<T> Ror(Vec512<T> a, Vec512<T> b) { |
| return Vec512<T>{_mm512_rorv_epi32(a.raw, b.raw)}; |
| } |
| |
| template <class T, HWY_IF_UI64(T)> |
| HWY_API Vec512<T> Rol(Vec512<T> a, Vec512<T> b) { |
| return Vec512<T>{_mm512_rolv_epi64(a.raw, b.raw)}; |
| } |
| |
| template <class T, HWY_IF_UI64(T)> |
| HWY_API Vec512<T> Ror(Vec512<T> a, Vec512<T> b) { |
| return Vec512<T>{_mm512_rorv_epi64(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ ShiftLeftSame |
| |
| // GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512 |
| // shift-with-immediate: the counts should all be unsigned int. Despite casting, |
| // we still see warnings in GCC debug builds, hence disable. |
| HWY_DIAGNOSTICS(push) |
| HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") |
| |
| #if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100 |
| using Shift16Count = int; |
| using Shift3264Count = int; |
| #elif HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400 |
| // GCC 11.0 requires these; prior versions used a macro+cast and don't care. |
| using Shift16Count = int; |
| using Shift3264Count = unsigned int; |
| #else |
| // Assume documented behavior. Clang 11, GCC 14 and MSVC 14.28.29910 match this. |
| using Shift16Count = unsigned int; |
| using Shift3264Count = unsigned int; |
| #endif |
| |
| HWY_API Vec512<uint16_t> ShiftLeftSame(const Vec512<uint16_t> v, |
| const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<uint16_t>{ |
| _mm512_slli_epi16(v.raw, static_cast<Shift16Count>(bits))}; |
| } |
| #endif |
| return Vec512<uint16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| HWY_API Vec512<uint32_t> ShiftLeftSame(const Vec512<uint32_t> v, |
| const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<uint32_t>{ |
| _mm512_slli_epi32(v.raw, static_cast<Shift3264Count>(bits))}; |
| } |
| #endif |
| return Vec512<uint32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| HWY_API Vec512<uint64_t> ShiftLeftSame(const Vec512<uint64_t> v, |
| const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<uint64_t>{ |
| _mm512_slli_epi64(v.raw, static_cast<Shift3264Count>(bits))}; |
| } |
| #endif |
| return Vec512<uint64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| |
| HWY_API Vec512<int16_t> ShiftLeftSame(const Vec512<int16_t> v, const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<int16_t>{ |
| _mm512_slli_epi16(v.raw, static_cast<Shift16Count>(bits))}; |
| } |
| #endif |
| return Vec512<int16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| |
| HWY_API Vec512<int32_t> ShiftLeftSame(const Vec512<int32_t> v, const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<int32_t>{ |
| _mm512_slli_epi32(v.raw, static_cast<Shift3264Count>(bits))}; |
| } |
| #endif |
| return Vec512<int32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| |
| HWY_API Vec512<int64_t> ShiftLeftSame(const Vec512<int64_t> v, const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<int64_t>{ |
| _mm512_slli_epi64(v.raw, static_cast<Shift3264Count>(bits))}; |
| } |
| #endif |
| return Vec512<int64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| |
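| // There is no native 8-bit shift: shift as 16-bit lanes, then mask off bits |
| // that crossed in from the neighboring byte. Example: for bits = 3, the mask |
| // is (0xFF << 3) & 0xFF = 0xF8. |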
| template <typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_API Vec512<T> ShiftLeftSame(const Vec512<T> v, const int bits) { |
| const DFromV<decltype(v)> d8; |
| const RepartitionToWide<decltype(d8)> d16; |
| const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); |
| return shifted & Set(d8, static_cast<T>((0xFF << bits) & 0xFF)); |
| } |
| |
| // ------------------------------ ShiftRightSame |
| |
| HWY_API Vec512<uint16_t> ShiftRightSame(const Vec512<uint16_t> v, |
| const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<uint16_t>{ |
| _mm512_srli_epi16(v.raw, static_cast<Shift16Count>(bits))}; |
| } |
| #endif |
| return Vec512<uint16_t>{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| HWY_API Vec512<uint32_t> ShiftRightSame(const Vec512<uint32_t> v, |
| const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<uint32_t>{ |
| _mm512_srli_epi32(v.raw, static_cast<Shift3264Count>(bits))}; |
| } |
| #endif |
| return Vec512<uint32_t>{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| HWY_API Vec512<uint64_t> ShiftRightSame(const Vec512<uint64_t> v, |
| const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<uint64_t>{ |
| _mm512_srli_epi64(v.raw, static_cast<Shift3264Count>(bits))}; |
| } |
| #endif |
| return Vec512<uint64_t>{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| |
| HWY_API Vec512<uint8_t> ShiftRightSame(Vec512<uint8_t> v, const int bits) { |
| const DFromV<decltype(v)> d8; |
| const RepartitionToWide<decltype(d8)> d16; |
| const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); |
| return shifted & Set(d8, static_cast<uint8_t>(0xFF >> bits)); |
| } |
| |
| HWY_API Vec512<int16_t> ShiftRightSame(const Vec512<int16_t> v, |
| const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<int16_t>{ |
| _mm512_srai_epi16(v.raw, static_cast<Shift16Count>(bits))}; |
| } |
| #endif |
| return Vec512<int16_t>{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| |
| HWY_API Vec512<int32_t> ShiftRightSame(const Vec512<int32_t> v, |
| const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<int32_t>{ |
| _mm512_srai_epi32(v.raw, static_cast<Shift3264Count>(bits))}; |
| } |
| #endif |
| return Vec512<int32_t>{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| HWY_API Vec512<int64_t> ShiftRightSame(const Vec512<int64_t> v, |
| const int bits) { |
| #if HWY_COMPILER_GCC |
| if (__builtin_constant_p(bits)) { |
| return Vec512<int64_t>{ |
| _mm512_srai_epi64(v.raw, static_cast<Shift3264Count>(bits))}; |
| } |
| #endif |
| return Vec512<int64_t>{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; |
| } |
| |
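| // Emulate the arithmetic shift via an unsigned shift plus sign extension: |
| // (x ^ m) - m sign-extends when m holds the shifted-down sign bit. |
| // Example: v = -128 (0x80), bits = 2: the logical shift gives 0x20, m = 0x20, |
| // and (0x20 ^ 0x20) - 0x20 = -32, matching -128 >> 2. |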
| HWY_API Vec512<int8_t> ShiftRightSame(Vec512<int8_t> v, const int bits) { |
| const DFromV<decltype(v)> di; |
| const RebindToUnsigned<decltype(di)> du; |
| const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); |
| const auto shifted_sign = |
| BitCast(di, Set(du, static_cast<uint8_t>(0x80 >> bits))); |
| return (shifted ^ shifted_sign) - shifted_sign; |
| } |
| |
| HWY_DIAGNOSTICS(pop) |
| |
| // ------------------------------ Minimum |
| |
| // Unsigned |
| HWY_API Vec512<uint8_t> Min(Vec512<uint8_t> a, Vec512<uint8_t> b) { |
| return Vec512<uint8_t>{_mm512_min_epu8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint16_t> Min(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Vec512<uint16_t>{_mm512_min_epu16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint32_t> Min(Vec512<uint32_t> a, Vec512<uint32_t> b) { |
| return Vec512<uint32_t>{_mm512_min_epu32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint64_t> Min(Vec512<uint64_t> a, Vec512<uint64_t> b) { |
| return Vec512<uint64_t>{_mm512_min_epu64(a.raw, b.raw)}; |
| } |
| |
| // Signed |
| HWY_API Vec512<int8_t> Min(Vec512<int8_t> a, Vec512<int8_t> b) { |
| return Vec512<int8_t>{_mm512_min_epi8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int16_t> Min(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Vec512<int16_t>{_mm512_min_epi16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int32_t> Min(Vec512<int32_t> a, Vec512<int32_t> b) { |
| return Vec512<int32_t>{_mm512_min_epi32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int64_t> Min(Vec512<int64_t> a, Vec512<int64_t> b) { |
| return Vec512<int64_t>{_mm512_min_epi64(a.raw, b.raw)}; |
| } |
| |
| // Float |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> Min(Vec512<float16_t> a, Vec512<float16_t> b) { |
| return Vec512<float16_t>{_mm512_min_ph(a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> Min(Vec512<float> a, Vec512<float> b) { |
| return Vec512<float>{_mm512_min_ps(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<double> Min(Vec512<double> a, Vec512<double> b) { |
| return Vec512<double>{_mm512_min_pd(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Maximum |
| |
| // Unsigned |
| HWY_API Vec512<uint8_t> Max(Vec512<uint8_t> a, Vec512<uint8_t> b) { |
| return Vec512<uint8_t>{_mm512_max_epu8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint16_t> Max(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Vec512<uint16_t>{_mm512_max_epu16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint32_t> Max(Vec512<uint32_t> a, Vec512<uint32_t> b) { |
| return Vec512<uint32_t>{_mm512_max_epu32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint64_t> Max(Vec512<uint64_t> a, Vec512<uint64_t> b) { |
| return Vec512<uint64_t>{_mm512_max_epu64(a.raw, b.raw)}; |
| } |
| |
| // Signed |
| HWY_API Vec512<int8_t> Max(Vec512<int8_t> a, Vec512<int8_t> b) { |
| return Vec512<int8_t>{_mm512_max_epi8(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int16_t> Max(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Vec512<int16_t>{_mm512_max_epi16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int32_t> Max(Vec512<int32_t> a, Vec512<int32_t> b) { |
| return Vec512<int32_t>{_mm512_max_epi32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int64_t> Max(Vec512<int64_t> a, Vec512<int64_t> b) { |
| return Vec512<int64_t>{_mm512_max_epi64(a.raw, b.raw)}; |
| } |
| |
| // Float |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> Max(Vec512<float16_t> a, Vec512<float16_t> b) { |
| return Vec512<float16_t>{_mm512_max_ph(a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> Max(Vec512<float> a, Vec512<float> b) { |
| return Vec512<float>{_mm512_max_ps(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<double> Max(Vec512<double> a, Vec512<double> b) { |
| return Vec512<double>{_mm512_max_pd(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Integer multiplication |
| |
| // Unsigned |
| HWY_API Vec512<uint16_t> operator*(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Vec512<uint16_t>{_mm512_mullo_epi16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint32_t> operator*(Vec512<uint32_t> a, Vec512<uint32_t> b) { |
| return Vec512<uint32_t>{_mm512_mullo_epi32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint64_t> operator*(Vec512<uint64_t> a, Vec512<uint64_t> b) { |
| return Vec512<uint64_t>{_mm512_mullo_epi64(a.raw, b.raw)}; |
| } |
| |
| // Signed |
| HWY_API Vec512<int16_t> operator*(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Vec512<int16_t>{_mm512_mullo_epi16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int32_t> operator*(Vec512<int32_t> a, Vec512<int32_t> b) { |
| return Vec512<int32_t>{_mm512_mullo_epi32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int64_t> operator*(Vec512<int64_t> a, Vec512<int64_t> b) { |
| return Vec512<int64_t>{_mm512_mullo_epi64(a.raw, b.raw)}; |
| } |
| |
| // Returns the upper 16 bits of a * b in each lane. |
| HWY_API Vec512<uint16_t> MulHigh(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Vec512<uint16_t>{_mm512_mulhi_epu16(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<int16_t> MulHigh(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Vec512<int16_t>{_mm512_mulhi_epi16(a.raw, b.raw)}; |
| } |
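| // Example (illustrative): for uint16_t lanes, MulHigh(0xFFFF, 0xFFFF) is |
| // 0xFFFE, the upper half of 0xFFFE0001. |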
| |
| HWY_API Vec512<int16_t> MulFixedPoint15(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Vec512<int16_t>{_mm512_mulhrs_epi16(a.raw, b.raw)}; |
| } |
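| // MulFixedPoint15 computes the Q15 product ((a * b) >> 15) with rounding. |
| // Example: 16384 * 16384 (0.5 * 0.5 in Q15) returns 8192 (0.25). |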
| |
| // Multiplies even lanes (0, 2, ...) and returns the double-wide result; its |
| // lower half is placed in the even lane and the upper half in the odd |
| // neighbor lane. |
| HWY_API Vec512<int64_t> MulEven(Vec512<int32_t> a, Vec512<int32_t> b) { |
| return Vec512<int64_t>{_mm512_mul_epi32(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<uint64_t> MulEven(Vec512<uint32_t> a, Vec512<uint32_t> b) { |
| return Vec512<uint64_t>{_mm512_mul_epu32(a.raw, b.raw)}; |
| } |
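| // Example (illustrative): for uint32_t lanes {2, 9, 3, 9, ..}, MulEven(v, v) |
| // returns uint64_t lanes {4, 9, ..}; the odd input lanes (here 9) do not |
| // participate. |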
| |
| // ------------------------------ Neg (Sub) |
| |
| template <typename T, HWY_IF_FLOAT_OR_SPECIAL(T)> |
| HWY_API Vec512<T> Neg(const Vec512<T> v) { |
| const DFromV<decltype(v)> d; |
| return Xor(v, SignBit(d)); |
| } |
| |
| template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)> |
| HWY_API Vec512<T> Neg(const Vec512<T> v) { |
| const DFromV<decltype(v)> d; |
| return Zero(d) - v; |
| } |
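| // Note: for floating-point types, Neg only flips the sign bit, so |
| // Neg(+0.0) is -0.0 and NaN payloads are preserved. |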
| |
| // ------------------------------ Floating-point mul / div |
| |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> operator*(Vec512<float16_t> a, Vec512<float16_t> b) { |
| return Vec512<float16_t>{_mm512_mul_ph(a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> operator*(Vec512<float> a, Vec512<float> b) { |
| return Vec512<float>{_mm512_mul_ps(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<double> operator*(Vec512<double> a, Vec512<double> b) { |
| return Vec512<double>{_mm512_mul_pd(a.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> MulByFloorPow2(Vec512<float16_t> a, |
| Vec512<float16_t> b) { |
| return Vec512<float16_t>{_mm512_scalef_ph(a.raw, b.raw)}; |
| } |
| #endif |
| |
| HWY_API Vec512<float> MulByFloorPow2(Vec512<float> a, Vec512<float> b) { |
| return Vec512<float>{_mm512_scalef_ps(a.raw, b.raw)}; |
| } |
| |
| HWY_API Vec512<double> MulByFloorPow2(Vec512<double> a, Vec512<double> b) { |
| return Vec512<double>{_mm512_scalef_pd(a.raw, b.raw)}; |
| } |
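| // Example (illustrative): MulByFloorPow2 scales by 2^Floor(b), so a = 3.0f |
| // and b = 2.5f return 3 * 2^2 = 12.0f. |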
| |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> operator/(Vec512<float16_t> a, Vec512<float16_t> b) { |
| return Vec512<float16_t>{_mm512_div_ph(a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> operator/(Vec512<float> a, Vec512<float> b) { |
| return Vec512<float>{_mm512_div_ps(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<double> operator/(Vec512<double> a, Vec512<double> b) { |
| return Vec512<double>{_mm512_div_pd(a.raw, b.raw)}; |
| } |
| |
| // Approximate reciprocal |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> ApproximateReciprocal(const Vec512<float16_t> v) { |
| return Vec512<float16_t>{_mm512_rcp_ph(v.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> ApproximateReciprocal(const Vec512<float> v) { |
| return Vec512<float>{_mm512_rcp14_ps(v.raw)}; |
| } |
| |
| HWY_API Vec512<double> ApproximateReciprocal(Vec512<double> v) { |
| return Vec512<double>{_mm512_rcp14_pd(v.raw)}; |
| } |
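| // rcp14 guarantees a relative error of at most 2^-14. Callers needing more |
| // precision typically apply one Newton-Raphson step, e.g. (sketch): |
| //   r = MulAdd(NegMulAdd(r, v, Set(d, 1.0f)), r, r);  // r += r * (1 - v * r) |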
| |
| // ------------------------------ GetExponent |
| |
| #if HWY_HAVE_FLOAT16 |
| template <class V, HWY_IF_F16(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)> |
| HWY_API V GetExponent(V v) { |
| return V{_mm512_getexp_ph(v.raw)}; |
| } |
| #endif |
| template <class V, HWY_IF_F32(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)> |
| HWY_API V GetExponent(V v) { |
| return V{_mm512_getexp_ps(v.raw)}; |
| } |
| template <class V, HWY_IF_F64(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)> |
| HWY_API V GetExponent(V v) { |
| return V{_mm512_getexp_pd(v.raw)}; |
| } |
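| // GetExponent returns the unbiased exponent as a float, i.e. |
| // Floor(Log2(Abs(v))). Example: GetExponent of 12.0f is 3.0f. |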
| |
| // ------------------------------ MaskedMinOr |
| |
| template <typename T, HWY_IF_U8(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_I8(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_U16(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_I16(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_U32(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_I32(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_U64(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_I64(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_F32(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_F64(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| template <typename T, HWY_IF_F16(T)> |
| HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| // ------------------------------ MaskedMaxOr |
| |
| template <typename T, HWY_IF_U8(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_I8(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_U16(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_I16(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_U32(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_I32(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_U64(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_I64(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_F32(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_F64(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| template <typename T, HWY_IF_F16(T)> |
| HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| // ------------------------------ MaskedAddOr |
| |
| template <typename T, HWY_IF_UI8(T)> |
| HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_UI16(T)> |
| HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_UI32(T)> |
| HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_UI64(T)> |
| HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_F32(T)> |
| HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_F64(T)> |
| HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| template <typename T, HWY_IF_F16(T)> |
| HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| // ------------------------------ MaskedSubOr |
| |
| template <typename T, HWY_IF_UI8(T)> |
| HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_UI16(T)> |
| HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_UI32(T)> |
| HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_UI64(T)> |
| HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_F32(T)> |
| HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_F64(T)> |
| HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| template <typename T, HWY_IF_F16(T)> |
| HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| // ------------------------------ MaskedMulOr |
| |
| HWY_API Vec512<float> MaskedMulOr(Vec512<float> no, Mask512<float> m, |
| Vec512<float> a, Vec512<float> b) { |
| return Vec512<float>{_mm512_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| HWY_API Vec512<double> MaskedMulOr(Vec512<double> no, Mask512<double> m, |
| Vec512<double> a, Vec512<double> b) { |
| return Vec512<double>{_mm512_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> MaskedMulOr(Vec512<float16_t> no, |
| Mask512<float16_t> m, Vec512<float16_t> a, |
| Vec512<float16_t> b) { |
| return Vec512<float16_t>{_mm512_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| // ------------------------------ MaskedDivOr |
| |
| HWY_API Vec512<float> MaskedDivOr(Vec512<float> no, Mask512<float> m, |
| Vec512<float> a, Vec512<float> b) { |
| return Vec512<float>{_mm512_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| HWY_API Vec512<double> MaskedDivOr(Vec512<double> no, Mask512<double> m, |
| Vec512<double> a, Vec512<double> b) { |
| return Vec512<double>{_mm512_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> MaskedDivOr(Vec512<float16_t> no, |
| Mask512<float16_t> m, Vec512<float16_t> a, |
| Vec512<float16_t> b) { |
| return Vec512<float16_t>{_mm512_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| // ------------------------------ MaskedSatAddOr |
| |
| template <typename T, HWY_IF_I8(T)> |
| HWY_API Vec512<T> MaskedSatAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_U8(T)> |
| HWY_API Vec512<T> MaskedSatAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_I16(T)> |
| HWY_API Vec512<T> MaskedSatAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_U16(T)> |
| HWY_API Vec512<T> MaskedSatAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ MaskedSatSubOr |
| |
| template <typename T, HWY_IF_I8(T)> |
| HWY_API Vec512<T> MaskedSatSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_U8(T)> |
| HWY_API Vec512<T> MaskedSatSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_I16(T)> |
| HWY_API Vec512<T> MaskedSatSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_U16(T)> |
| HWY_API Vec512<T> MaskedSatSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, |
| Vec512<T> b) { |
| return Vec512<T>{_mm512_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Floating-point multiply-add variants |
| |
| #if HWY_HAVE_FLOAT16 |
| |
| HWY_API Vec512<float16_t> MulAdd(Vec512<float16_t> mul, Vec512<float16_t> x, |
| Vec512<float16_t> add) { |
| return Vec512<float16_t>{_mm512_fmadd_ph(mul.raw, x.raw, add.raw)}; |
| } |
| |
| HWY_API Vec512<float16_t> NegMulAdd(Vec512<float16_t> mul, Vec512<float16_t> x, |
| Vec512<float16_t> add) { |
| return Vec512<float16_t>{_mm512_fnmadd_ph(mul.raw, x.raw, add.raw)}; |
| } |
| |
| HWY_API Vec512<float16_t> MulSub(Vec512<float16_t> mul, Vec512<float16_t> x, |
| Vec512<float16_t> sub) { |
| return Vec512<float16_t>{_mm512_fmsub_ph(mul.raw, x.raw, sub.raw)}; |
| } |
| |
| HWY_API Vec512<float16_t> NegMulSub(Vec512<float16_t> mul, Vec512<float16_t> x, |
| Vec512<float16_t> sub) { |
| return Vec512<float16_t>{_mm512_fnmsub_ph(mul.raw, x.raw, sub.raw)}; |
| } |
| |
| #endif // HWY_HAVE_FLOAT16 |
| |
| // Returns mul * x + add |
| HWY_API Vec512<float> MulAdd(Vec512<float> mul, Vec512<float> x, |
| Vec512<float> add) { |
| return Vec512<float>{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)}; |
| } |
| HWY_API Vec512<double> MulAdd(Vec512<double> mul, Vec512<double> x, |
| Vec512<double> add) { |
| return Vec512<double>{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)}; |
| } |
| |
| // Returns add - mul * x |
| HWY_API Vec512<float> NegMulAdd(Vec512<float> mul, Vec512<float> x, |
| Vec512<float> add) { |
| return Vec512<float>{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)}; |
| } |
| HWY_API Vec512<double> NegMulAdd(Vec512<double> mul, Vec512<double> x, |
| Vec512<double> add) { |
| return Vec512<double>{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)}; |
| } |
| |
| // Returns mul * x - sub |
| HWY_API Vec512<float> MulSub(Vec512<float> mul, Vec512<float> x, |
| Vec512<float> sub) { |
| return Vec512<float>{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)}; |
| } |
| HWY_API Vec512<double> MulSub(Vec512<double> mul, Vec512<double> x, |
| Vec512<double> sub) { |
| return Vec512<double>{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)}; |
| } |
| |
| // Returns -mul * x - sub |
| HWY_API Vec512<float> NegMulSub(Vec512<float> mul, Vec512<float> x, |
| Vec512<float> sub) { |
| return Vec512<float>{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)}; |
| } |
| HWY_API Vec512<double> NegMulSub(Vec512<double> mul, Vec512<double> x, |
| Vec512<double> sub) { |
| return Vec512<double>{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)}; |
| } |
| |
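| // MulAddSub returns mul * x - sub_or_add in even lanes and |
| // mul * x + sub_or_add in odd lanes. |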
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> MulAddSub(Vec512<float16_t> mul, Vec512<float16_t> x, |
| Vec512<float16_t> sub_or_add) { |
| return Vec512<float16_t>{_mm512_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| HWY_API Vec512<float> MulAddSub(Vec512<float> mul, Vec512<float> x, |
| Vec512<float> sub_or_add) { |
| return Vec512<float>{_mm512_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; |
| } |
| |
| HWY_API Vec512<double> MulAddSub(Vec512<double> mul, Vec512<double> x, |
| Vec512<double> sub_or_add) { |
| return Vec512<double>{_mm512_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; |
| } |
| |
| // ------------------------------ Floating-point square root |
| |
| // Full precision square root |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> Sqrt(const Vec512<float16_t> v) { |
| return Vec512<float16_t>{_mm512_sqrt_ph(v.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> Sqrt(const Vec512<float> v) { |
| return Vec512<float>{_mm512_sqrt_ps(v.raw)}; |
| } |
| HWY_API Vec512<double> Sqrt(const Vec512<double> v) { |
| return Vec512<double>{_mm512_sqrt_pd(v.raw)}; |
| } |
| |
| // Approximate reciprocal square root |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> ApproximateReciprocalSqrt(Vec512<float16_t> v) { |
| return Vec512<float16_t>{_mm512_rsqrt_ph(v.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> ApproximateReciprocalSqrt(Vec512<float> v) { |
| return Vec512<float>{_mm512_rsqrt14_ps(v.raw)}; |
| } |
| |
| HWY_API Vec512<double> ApproximateReciprocalSqrt(Vec512<double> v) { |
| return Vec512<double>{_mm512_rsqrt14_pd(v.raw)}; |
| } |
| |
| // ------------------------------ Floating-point rounding |
| |
| // Work around warnings in the intrinsic definitions (passing -1 as a mask). |
| HWY_DIAGNOSTICS(push) |
| HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") |
| |
| // Toward nearest integer, tie to even |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> Round(Vec512<float16_t> v) { |
| return Vec512<float16_t>{_mm512_roundscale_ph( |
| v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> Round(Vec512<float> v) { |
| return Vec512<float>{_mm512_roundscale_ps( |
| v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; |
| } |
| HWY_API Vec512<double> Round(Vec512<double> v) { |
| return Vec512<double>{_mm512_roundscale_pd( |
| v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; |
| } |
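| // Note the tie-to-even behavior: Round(2.5f) is 2.0f, whereas Round(3.5f) is |
| // 4.0f. |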
| |
| // Toward zero, aka truncate |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> Trunc(Vec512<float16_t> v) { |
| return Vec512<float16_t>{ |
| _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> Trunc(Vec512<float> v) { |
| return Vec512<float>{ |
| _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; |
| } |
| HWY_API Vec512<double> Trunc(Vec512<double> v) { |
| return Vec512<double>{ |
| _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; |
| } |
| |
| // Toward +infinity, aka ceiling |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> Ceil(Vec512<float16_t> v) { |
| return Vec512<float16_t>{ |
| _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> Ceil(Vec512<float> v) { |
| return Vec512<float>{ |
| _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; |
| } |
| HWY_API Vec512<double> Ceil(Vec512<double> v) { |
| return Vec512<double>{ |
| _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; |
| } |
| |
| // Toward -infinity, aka floor |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> Floor(Vec512<float16_t> v) { |
| return Vec512<float16_t>{ |
| _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> Floor(Vec512<float> v) { |
| return Vec512<float>{ |
| _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; |
| } |
| HWY_API Vec512<double> Floor(Vec512<double> v) { |
| return Vec512<double>{ |
| _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; |
| } |
| |
| HWY_DIAGNOSTICS(pop) |
| |
| // ================================================== COMPARE |
| |
| // Comparisons set a mask bit to 1 if the condition is true, else 0. |
| |
| template <class DTo, typename TFrom> |
| HWY_API MFromD<DTo> RebindMask(DTo /*tag*/, Mask512<TFrom> m) { |
| static_assert(sizeof(TFrom) == sizeof(TFromD<DTo>), "Must have same size"); |
| return MFromD<DTo>{m.raw}; |
| } |
| |
| namespace detail { |
| |
| template <typename T> |
| HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<1> /*tag*/, Vec512<T> v, |
| Vec512<T> bit) { |
| return Mask512<T>{_mm512_test_epi8_mask(v.raw, bit.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<2> /*tag*/, Vec512<T> v, |
| Vec512<T> bit) { |
| return Mask512<T>{_mm512_test_epi16_mask(v.raw, bit.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<4> /*tag*/, Vec512<T> v, |
| Vec512<T> bit) { |
| return Mask512<T>{_mm512_test_epi32_mask(v.raw, bit.raw)}; |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<8> /*tag*/, Vec512<T> v, |
| Vec512<T> bit) { |
| return Mask512<T>{_mm512_test_epi64_mask(v.raw, bit.raw)}; |
| } |
| |
| } // namespace detail |
| |
| template <typename T> |
| HWY_API Mask512<T> TestBit(const Vec512<T> v, const Vec512<T> bit) { |
| static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); |
| return detail::TestBit(hwy::SizeTag<sizeof(T)>(), v, bit); |
| } |
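| // TestBit returns, per lane, whether the single bit set in `bit` is also set |
| // in v, i.e. (v & bit) != 0. |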
| |
| // ------------------------------ Equality |
| |
| template <typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { |
| return Mask512<T>{_mm512_cmpeq_epi8_mask(a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_T_SIZE(T, 2)> |
| HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { |
| return Mask512<T>{_mm512_cmpeq_epi16_mask(a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_UI32(T)> |
| HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { |
| return Mask512<T>{_mm512_cmpeq_epi32_mask(a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_UI64(T)> |
| HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { |
| return Mask512<T>{_mm512_cmpeq_epi64_mask(a.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Mask512<float16_t> operator==(Vec512<float16_t> a, |
| Vec512<float16_t> b) { |
| // Work around warnings in the intrinsic definitions (passing -1 as a mask). |
| HWY_DIAGNOSTICS(push) |
| HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") |
| return Mask512<float16_t>{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; |
| HWY_DIAGNOSTICS(pop) |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| HWY_API Mask512<float> operator==(Vec512<float> a, Vec512<float> b) { |
| return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; |
| } |
| |
| HWY_API Mask512<double> operator==(Vec512<double> a, Vec512<double> b) { |
| return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; |
| } |
| |
| // ------------------------------ Inequality |
| |
| template <typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) { |
| return Mask512<T>{_mm512_cmpneq_epi8_mask(a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_T_SIZE(T, 2)> |
| HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) { |
| return Mask512<T>{_mm512_cmpneq_epi16_mask(a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_UI32(T)> |
| HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) { |
| return Mask512<T>{_mm512_cmpneq_epi32_mask(a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_UI64(T)> |
| HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) { |
| return Mask512<T>{_mm512_cmpneq_epi64_mask(a.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Mask512<float16_t> operator!=(Vec512<float16_t> a, |
| Vec512<float16_t> b) { |
| // Work around warnings in the intrinsic definitions (passing -1 as a mask). |
| HWY_DIAGNOSTICS(push) |
| HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") |
| return Mask512<float16_t>{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; |
| HWY_DIAGNOSTICS(pop) |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| HWY_API Mask512<float> operator!=(Vec512<float> a, Vec512<float> b) { |
| return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; |
| } |
| |
| HWY_API Mask512<double> operator!=(Vec512<double> a, Vec512<double> b) { |
| return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; |
| } |
| |
| // ------------------------------ Strict inequality |
| |
| HWY_API Mask512<uint8_t> operator>(Vec512<uint8_t> a, Vec512<uint8_t> b) { |
| return Mask512<uint8_t>{_mm512_cmpgt_epu8_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<uint16_t> operator>(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Mask512<uint16_t>{_mm512_cmpgt_epu16_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<uint32_t> operator>(Vec512<uint32_t> a, Vec512<uint32_t> b) { |
| return Mask512<uint32_t>{_mm512_cmpgt_epu32_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<uint64_t> operator>(Vec512<uint64_t> a, Vec512<uint64_t> b) { |
| return Mask512<uint64_t>{_mm512_cmpgt_epu64_mask(a.raw, b.raw)}; |
| } |
| |
| HWY_API Mask512<int8_t> operator>(Vec512<int8_t> a, Vec512<int8_t> b) { |
| return Mask512<int8_t>{_mm512_cmpgt_epi8_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<int16_t> operator>(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Mask512<int16_t>{_mm512_cmpgt_epi16_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<int32_t> operator>(Vec512<int32_t> a, Vec512<int32_t> b) { |
| return Mask512<int32_t>{_mm512_cmpgt_epi32_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<int64_t> operator>(Vec512<int64_t> a, Vec512<int64_t> b) { |
| return Mask512<int64_t>{_mm512_cmpgt_epi64_mask(a.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Mask512<float16_t> operator>(Vec512<float16_t> a, Vec512<float16_t> b) { |
| // Work around warnings in the intrinsic definitions (passing -1 as a mask). |
| HWY_DIAGNOSTICS(push) |
| HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") |
| return Mask512<float16_t>{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; |
| HWY_DIAGNOSTICS(pop) |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| HWY_API Mask512<float> operator>(Vec512<float> a, Vec512<float> b) { |
| return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; |
| } |
| HWY_API Mask512<double> operator>(Vec512<double> a, Vec512<double> b) { |
| return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; |
| } |
| |
| // ------------------------------ Weak inequality |
| |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Mask512<float16_t> operator>=(Vec512<float16_t> a, |
| Vec512<float16_t> b) { |
| // Work around warnings in the intrinsic definitions (passing -1 as a mask). |
| HWY_DIAGNOSTICS(push) |
| HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") |
| return Mask512<float16_t>{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; |
| HWY_DIAGNOSTICS(pop) |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| HWY_API Mask512<float> operator>=(Vec512<float> a, Vec512<float> b) { |
| return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; |
| } |
| HWY_API Mask512<double> operator>=(Vec512<double> a, Vec512<double> b) { |
| return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; |
| } |
| |
| HWY_API Mask512<uint8_t> operator>=(Vec512<uint8_t> a, Vec512<uint8_t> b) { |
| return Mask512<uint8_t>{_mm512_cmpge_epu8_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<uint16_t> operator>=(Vec512<uint16_t> a, Vec512<uint16_t> b) { |
| return Mask512<uint16_t>{_mm512_cmpge_epu16_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<uint32_t> operator>=(Vec512<uint32_t> a, Vec512<uint32_t> b) { |
| return Mask512<uint32_t>{_mm512_cmpge_epu32_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<uint64_t> operator>=(Vec512<uint64_t> a, Vec512<uint64_t> b) { |
| return Mask512<uint64_t>{_mm512_cmpge_epu64_mask(a.raw, b.raw)}; |
| } |
| |
| HWY_API Mask512<int8_t> operator>=(Vec512<int8_t> a, Vec512<int8_t> b) { |
| return Mask512<int8_t>{_mm512_cmpge_epi8_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<int16_t> operator>=(Vec512<int16_t> a, Vec512<int16_t> b) { |
| return Mask512<int16_t>{_mm512_cmpge_epi16_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<int32_t> operator>=(Vec512<int32_t> a, Vec512<int32_t> b) { |
| return Mask512<int32_t>{_mm512_cmpge_epi32_mask(a.raw, b.raw)}; |
| } |
| HWY_API Mask512<int64_t> operator>=(Vec512<int64_t> a, Vec512<int64_t> b) { |
| return Mask512<int64_t>{_mm512_cmpge_epi64_mask(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Reversed comparisons |
| |
| template <typename T> |
| HWY_API Mask512<T> operator<(Vec512<T> a, Vec512<T> b) { |
| return b > a; |
| } |
| |
| template <typename T> |
| HWY_API Mask512<T> operator<=(Vec512<T> a, Vec512<T> b) { |
| return b >= a; |
| } |
| |
| // ------------------------------ Mask |
| |
| template <typename T, HWY_IF_UI8(T)> |
| HWY_API Mask512<T> MaskFromVec(Vec512<T> v) { |
| return Mask512<T>{_mm512_movepi8_mask(v.raw)}; |
| } |
| template <typename T, HWY_IF_UI16(T)> |
| HWY_API Mask512<T> MaskFromVec(Vec512<T> v) { |
| return Mask512<T>{_mm512_movepi16_mask(v.raw)}; |
| } |
| template <typename T, HWY_IF_UI32(T)> |
| HWY_API Mask512<T> MaskFromVec(Vec512<T> v) { |
| return Mask512<T>{_mm512_movepi32_mask(v.raw)}; |
| } |
| template <typename T, HWY_IF_UI64(T)> |
| HWY_API Mask512<T> MaskFromVec(Vec512<T> v) { |
| return Mask512<T>{_mm512_movepi64_mask(v.raw)}; |
| } |
| template <typename T, HWY_IF_FLOAT_OR_SPECIAL(T)> |
| HWY_API Mask512<T> MaskFromVec(Vec512<T> v) { |
| const RebindToSigned<DFromV<decltype(v)>> di; |
| return Mask512<T>{MaskFromVec(BitCast(di, v)).raw}; |
| } |
| |
| template <typename T, HWY_IF_UI8(T)> |
| HWY_API Vec512<T> VecFromMask(Mask512<T> m) { |
| return Vec512<T>{_mm512_movm_epi8(m.raw)}; |
| } |
| template <typename T, HWY_IF_UI16(T)> |
| HWY_API Vec512<T> VecFromMask(Mask512<T> m) { |
| return Vec512<T>{_mm512_movm_epi16(m.raw)}; |
| } |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> VecFromMask(Mask512<float16_t> m) { |
| return Vec512<float16_t>{_mm512_castsi512_ph(_mm512_movm_epi16(m.raw))}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| template <typename T, HWY_IF_UI32(T)> |
| HWY_API Vec512<T> VecFromMask(Mask512<T> m) { |
| return Vec512<T>{_mm512_movm_epi32(m.raw)}; |
| } |
| template <typename T, HWY_IF_UI64(T)> |
| HWY_API Vec512<T> VecFromMask(Mask512<T> m) { |
| return Vec512<T>{_mm512_movm_epi64(m.raw)}; |
| } |
| template <typename T, HWY_IF_FLOAT_OR_SPECIAL(T)> |
| HWY_API Vec512<T> VecFromMask(Mask512<T> m) { |
| const Full512<T> d; |
| const Full512<MakeSigned<T>> di; |
| return BitCast(d, VecFromMask(RebindMask(di, m))); |
| } |
| |
| // ------------------------------ Mask logical |
| |
| namespace detail { |
| |
| template <typename T> |
| HWY_INLINE Mask512<T> Not(hwy::SizeTag<1> /*tag*/, Mask512<T> m) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_knot_mask64(m.raw)}; |
| #else |
| return Mask512<T>{~m.raw}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> Not(hwy::SizeTag<2> /*tag*/, Mask512<T> m) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_knot_mask32(m.raw)}; |
| #else |
| return Mask512<T>{~m.raw}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> Not(hwy::SizeTag<4> /*tag*/, Mask512<T> m) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_knot_mask16(m.raw)}; |
| #else |
| return Mask512<T>{static_cast<uint16_t>(~m.raw & 0xFFFF)}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> Not(hwy::SizeTag<8> /*tag*/, Mask512<T> m) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_knot_mask8(m.raw)}; |
| #else |
| return Mask512<T>{static_cast<uint8_t>(~m.raw & 0xFF)}; |
| #endif |
| } |
| |
| template <typename T> |
| HWY_INLINE Mask512<T> And(hwy::SizeTag<1> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kand_mask64(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{a.raw & b.raw}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> And(hwy::SizeTag<2> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kand_mask32(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{a.raw & b.raw}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> And(hwy::SizeTag<4> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kand_mask16(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<uint16_t>(a.raw & b.raw)}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> And(hwy::SizeTag<8> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kand_mask8(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<uint8_t>(a.raw & b.raw)}; |
| #endif |
| } |
| |
| template <typename T> |
| HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<1> /*tag*/, Mask512<T> a, |
| Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kandn_mask64(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{~a.raw & b.raw}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<2> /*tag*/, Mask512<T> a, |
| Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kandn_mask32(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{~a.raw & b.raw}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<4> /*tag*/, Mask512<T> a, |
| Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kandn_mask16(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<uint16_t>(~a.raw & b.raw)}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<8> /*tag*/, Mask512<T> a, |
| Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kandn_mask8(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<uint8_t>(~a.raw & b.raw)}; |
| #endif |
| } |
| |
| template <typename T> |
| HWY_INLINE Mask512<T> Or(hwy::SizeTag<1> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kor_mask64(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{a.raw | b.raw}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> Or(hwy::SizeTag<2> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kor_mask32(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{a.raw | b.raw}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> Or(hwy::SizeTag<4> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kor_mask16(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<uint16_t>(a.raw | b.raw)}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> Or(hwy::SizeTag<8> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kor_mask8(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<uint8_t>(a.raw | b.raw)}; |
| #endif |
| } |
| |
| template <typename T> |
| HWY_INLINE Mask512<T> Xor(hwy::SizeTag<1> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kxor_mask64(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{a.raw ^ b.raw}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> Xor(hwy::SizeTag<2> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kxor_mask32(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{a.raw ^ b.raw}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> Xor(hwy::SizeTag<4> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kxor_mask16(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<uint16_t>(a.raw ^ b.raw)}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> Xor(hwy::SizeTag<8> /*tag*/, Mask512<T> a, Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kxor_mask8(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<uint8_t>(a.raw ^ b.raw)}; |
| #endif |
| } |
| |
| template <typename T> |
| HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<1> /*tag*/, Mask512<T> a, |
| Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kxnor_mask64(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{~(a.raw ^ b.raw)}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<2> /*tag*/, Mask512<T> a, |
| Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kxnor_mask32(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<4> /*tag*/, Mask512<T> a, |
| Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kxnor_mask16(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<8> /*tag*/, Mask512<T> a, |
| Mask512<T> b) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return Mask512<T>{_kxnor_mask8(a.raw, b.raw)}; |
| #else |
| return Mask512<T>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; |
| #endif |
| } |
| |
| } // namespace detail |
| |
| template <typename T> |
| HWY_API Mask512<T> Not(Mask512<T> m) { |
| return detail::Not(hwy::SizeTag<sizeof(T)>(), m); |
| } |
| |
| template <typename T> |
| HWY_API Mask512<T> And(Mask512<T> a, Mask512<T> b) { |
| return detail::And(hwy::SizeTag<sizeof(T)>(), a, b); |
| } |
| |
| template <typename T> |
| HWY_API Mask512<T> AndNot(Mask512<T> a, Mask512<T> b) { |
| return detail::AndNot(hwy::SizeTag<sizeof(T)>(), a, b); |
| } |
| |
| template <typename T> |
| HWY_API Mask512<T> Or(Mask512<T> a, Mask512<T> b) { |
| return detail::Or(hwy::SizeTag<sizeof(T)>(), a, b); |
| } |
| |
| template <typename T> |
| HWY_API Mask512<T> Xor(Mask512<T> a, Mask512<T> b) { |
| return detail::Xor(hwy::SizeTag<sizeof(T)>(), a, b); |
| } |
| |
| template <typename T> |
| HWY_API Mask512<T> ExclusiveNeither(Mask512<T> a, Mask512<T> b) { |
| return detail::ExclusiveNeither(hwy::SizeTag<sizeof(T)>(), a, b); |
| } |
| |
| template <class D, HWY_IF_LANES_D(D, 64)> |
| HWY_API MFromD<D> CombineMasks(D /*d*/, MFromD<Half<D>> hi, |
| MFromD<Half<D>> lo) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| const __mmask64 combined_mask = _mm512_kunpackd( |
| static_cast<__mmask64>(hi.raw), static_cast<__mmask64>(lo.raw)); |
| #else |
| const __mmask64 combined_mask = static_cast<__mmask64>( |
| ((static_cast<uint64_t>(hi.raw) << 32) | (lo.raw & 0xFFFFFFFFULL))); |
| #endif |
| |
| return MFromD<D>{combined_mask}; |
| } |
| |
| template <class D, HWY_IF_LANES_D(D, 32)> |
| HWY_API MFromD<D> UpperHalfOfMask(D /*d*/, MFromD<Twice<D>> m) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| const auto shifted_mask = _kshiftri_mask64(static_cast<__mmask64>(m.raw), 32); |
| #else |
| const auto shifted_mask = static_cast<uint64_t>(m.raw) >> 32; |
| #endif |
| |
| return MFromD<D>{static_cast<decltype(MFromD<D>().raw)>(shifted_mask)}; |
| } |
| |
| template <class D, HWY_IF_LANES_D(D, 64)> |
| HWY_API MFromD<D> SlideMask1Up(D /*d*/, MFromD<D> m) { |
| using RawM = decltype(MFromD<D>().raw); |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return MFromD<D>{ |
| static_cast<RawM>(_kshiftli_mask64(static_cast<__mmask64>(m.raw), 1))}; |
| #else |
| return MFromD<D>{static_cast<RawM>(static_cast<uint64_t>(m.raw) << 1)}; |
| #endif |
| } |
| |
| template <class D, HWY_IF_LANES_D(D, 64)> |
| HWY_API MFromD<D> SlideMask1Down(D /*d*/, MFromD<D> m) { |
| using RawM = decltype(MFromD<D>().raw); |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return MFromD<D>{ |
| static_cast<RawM>(_kshiftri_mask64(static_cast<__mmask64>(m.raw), 1))}; |
| #else |
| return MFromD<D>{static_cast<RawM>(static_cast<uint64_t>(m.raw) >> 1)}; |
| #endif |
| } |
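| // Example (illustrative): SlideMask1Up moves mask bit i to lane i + 1 and |
| // clears lane 0; SlideMask1Down moves bit i to lane i - 1 and clears the |
| // uppermost lane. |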
| |
| // ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) |
| |
| HWY_API Vec512<int8_t> BroadcastSignBit(Vec512<int8_t> v) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| const Repartition<uint64_t, DFromV<decltype(v)>> du64; |
| return detail::GaloisAffine(v, Set(du64, 0x8080808080808080ull)); |
| #else |
| const DFromV<decltype(v)> d; |
| return VecFromMask(v < Zero(d)); |
| #endif |
| } |
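| // (The AVX3_DL path's affine matrix has every row equal to 0x80, so each |
| // output bit copies input bit 7; lanes thus become 0 or -1.) |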
| |
| HWY_API Vec512<int16_t> BroadcastSignBit(Vec512<int16_t> v) { |
| return ShiftRight<15>(v); |
| } |
| |
| HWY_API Vec512<int32_t> BroadcastSignBit(Vec512<int32_t> v) { |
| return ShiftRight<31>(v); |
| } |
| |
| HWY_API Vec512<int64_t> BroadcastSignBit(Vec512<int64_t> v) { |
| return ShiftRight<63>(v); |
| } |
| |
| // ------------------------------ Floating-point classification (Not) |
| |
| #if HWY_HAVE_FLOAT16 || HWY_IDE |
| |
| namespace detail { |
| |
| template <int kCategories> |
| __mmask32 Fix_mm512_fpclass_ph_mask(__m512h v) { |
| #if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1500 |
| // Prior to GCC 15, _mm512_fpclass_ph_mask returns __mmask8 instead of |
| // __mmask32, hence only the first 8 lanes would be set. Call the builtin |
| // directly instead. |
| return static_cast<__mmask32>(__builtin_ia32_fpclassph512_mask( |
| static_cast<__v32hf>(v), kCategories, static_cast<__mmask32>(-1))); |
| #else |
| return _mm512_fpclass_ph_mask(v, kCategories); |
| #endif |
| } |
| |
| } // namespace detail |
| |
| HWY_API Mask512<float16_t> IsNaN(Vec512<float16_t> v) { |
| constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN; |
| return Mask512<float16_t>{ |
| detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)}; |
| } |
| |
| HWY_API Mask512<float16_t> IsEitherNaN(Vec512<float16_t> a, |
| Vec512<float16_t> b) { |
| // Work around warnings in the intrinsic definitions (passing -1 as a mask). |
| HWY_DIAGNOSTICS(push) |
| HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") |
| return Mask512<float16_t>{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)}; |
| HWY_DIAGNOSTICS(pop) |
| } |
| |
| HWY_API Mask512<float16_t> IsInf(Vec512<float16_t> v) { |
| constexpr int kCategories = HWY_X86_FPCLASS_POS_INF | HWY_X86_FPCLASS_NEG_INF; |
| return Mask512<float16_t>{ |
| detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)}; |
| } |
| |
| // Returns whether normal/subnormal/zero. fpclass does not have a flag for |
| // "finite", so we instead check for inf/NaN and negate. |
| HWY_API Mask512<float16_t> IsFinite(Vec512<float16_t> v) { |
| constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | |
| HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF; |
| return Not(Mask512<float16_t>{ |
| detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)}); |
| } |
| |
| #endif // HWY_HAVE_FLOAT16 |
| |
| HWY_API Mask512<float> IsNaN(Vec512<float> v) { |
| return Mask512<float>{_mm512_fpclass_ps_mask( |
| v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; |
| } |
| HWY_API Mask512<double> IsNaN(Vec512<double> v) { |
| return Mask512<double>{_mm512_fpclass_pd_mask( |
| v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; |
| } |
| |
| HWY_API Mask512<float> IsEitherNaN(Vec512<float> a, Vec512<float> b) { |
| return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; |
| } |
| |
| HWY_API Mask512<double> IsEitherNaN(Vec512<double> a, Vec512<double> b) { |
| return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; |
| } |
| |
| HWY_API Mask512<float> IsInf(Vec512<float> v) { |
| return Mask512<float>{_mm512_fpclass_ps_mask( |
| v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; |
| } |
| HWY_API Mask512<double> IsInf(Vec512<double> v) { |
| return Mask512<double>{_mm512_fpclass_pd_mask( |
| v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; |
| } |
| |
| // Returns whether normal/subnormal/zero. fpclass does not have a flag for |
| // "finite", so we instead check for inf/NaN and negate. |
| HWY_API Mask512<float> IsFinite(Vec512<float> v) { |
| return Not(Mask512<float>{_mm512_fpclass_ps_mask( |
| v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | |
| HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); |
| } |
| HWY_API Mask512<double> IsFinite(Vec512<double> v) { |
| return Not(Mask512<double>{_mm512_fpclass_pd_mask( |
| v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | |
| HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); |
| } |
| |
| // ================================================== MEMORY |
| |
| // ------------------------------ Load |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> |
| HWY_API VFromD<D> Load(D /* tag */, const TFromD<D>* HWY_RESTRICT aligned) { |
| return VFromD<D>{_mm512_load_si512(aligned)}; |
| } |
| // bfloat16_t is handled by x86_128-inl.h. |
| #if HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_API Vec512<float16_t> Load(D /* tag */, |
| const float16_t* HWY_RESTRICT aligned) { |
| return Vec512<float16_t>{_mm512_load_ph(aligned)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API Vec512<float> Load(D /* tag */, const float* HWY_RESTRICT aligned) { |
| return Vec512<float>{_mm512_load_ps(aligned)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> Load(D /* tag */, const double* HWY_RESTRICT aligned) { |
| return VFromD<D>{_mm512_load_pd(aligned)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> |
| HWY_API VFromD<D> LoadU(D /* tag */, const TFromD<D>* HWY_RESTRICT p) { |
| return VFromD<D>{_mm512_loadu_si512(p)}; |
| } |
| |
| // bfloat16_t is handled by x86_128-inl.h. |
| #if HWY_HAVE_FLOAT16 |
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
| HWY_API Vec512<float16_t> LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { |
| return Vec512<float16_t>{_mm512_loadu_ph(p)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
| HWY_API Vec512<float> LoadU(D /* tag */, const float* HWY_RESTRICT p) { |
| return Vec512<float>{_mm512_loadu_ps(p)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> LoadU(D /* tag */, const double* HWY_RESTRICT p) { |
| return VFromD<D>{_mm512_loadu_pd(p)}; |
| } |
| |
| // ------------------------------ MaskedLoad |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */, |
| const TFromD<D>* HWY_RESTRICT p) { |
| return VFromD<D>{_mm512_maskz_loadu_epi8(m.raw, p)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D d, |
| const TFromD<D>* HWY_RESTRICT p) { |
| const RebindToUnsigned<D> du; // for float16_t |
| return BitCast(d, VFromD<decltype(du)>{_mm512_maskz_loadu_epi16( |
| m.raw, reinterpret_cast<const uint16_t*>(p))}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */, |
| const TFromD<D>* HWY_RESTRICT p) { |
| return VFromD<D>{_mm512_maskz_loadu_epi32(m.raw, p)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */, |
| const TFromD<D>* HWY_RESTRICT p) { |
| return VFromD<D>{_mm512_maskz_loadu_epi64(m.raw, p)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API Vec512<float> MaskedLoad(Mask512<float> m, D /* tag */, |
| const float* HWY_RESTRICT p) { |
| return Vec512<float>{_mm512_maskz_loadu_ps(m.raw, p)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API Vec512<double> MaskedLoad(Mask512<double> m, D /* tag */, |
| const double* HWY_RESTRICT p) { |
| return Vec512<double>{_mm512_maskz_loadu_pd(m.raw, p)}; |
| } |
| |
| // ------------------------------ MaskedLoadOr |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D /* tag */, |
| const TFromD<D>* HWY_RESTRICT p) { |
| return VFromD<D>{_mm512_mask_loadu_epi8(v.raw, m.raw, p)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D d, |
| const TFromD<D>* HWY_RESTRICT p) { |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast( |
| d, VFromD<decltype(du)>{_mm512_mask_loadu_epi16( |
| BitCast(du, v).raw, m.raw, reinterpret_cast<const uint16_t*>(p))}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D /* tag */, |
| const TFromD<D>* HWY_RESTRICT p) { |
| return VFromD<D>{_mm512_mask_loadu_epi32(v.raw, m.raw, p)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D /* tag */, |
| const TFromD<D>* HWY_RESTRICT p) { |
| return VFromD<D>{_mm512_mask_loadu_epi64(v.raw, m.raw, p)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, Mask512<float> m, D /* tag */, |
| const float* HWY_RESTRICT p) { |
| return VFromD<D>{_mm512_mask_loadu_ps(v.raw, m.raw, p)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, Mask512<double> m, D /* tag */, |
| const double* HWY_RESTRICT p) { |
| return VFromD<D>{_mm512_mask_loadu_pd(v.raw, m.raw, p)}; |
| } |
| |
| // ------------------------------ LoadDup128 |
| |
// Loads 128 bits and duplicates them into all four 128-bit blocks. This avoids
// the 3-cycle cost of moving data between 128-bit halves and avoids port 5.
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)> |
| HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* const HWY_RESTRICT p) { |
| const RebindToUnsigned<decltype(d)> du; |
| const Full128<TFromD<D>> d128; |
| const RebindToUnsigned<decltype(d128)> du128; |
| return BitCast(d, VFromD<decltype(du)>{_mm512_broadcast_i32x4( |
| BitCast(du128, LoadU(d128, p)).raw)}); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> LoadDup128(D /* tag */, const float* HWY_RESTRICT p) { |
| const __m128 x4 = _mm_loadu_ps(p); |
| return VFromD<D>{_mm512_broadcast_f32x4(x4)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> LoadDup128(D /* tag */, const double* HWY_RESTRICT p) { |
| const __m128d x2 = _mm_loadu_pd(p); |
| return VFromD<D>{_mm512_broadcast_f64x2(x2)}; |
| } |
| |
| // ------------------------------ Store |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> |
| HWY_API void Store(VFromD<D> v, D /* tag */, TFromD<D>* HWY_RESTRICT aligned) { |
| _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); |
| } |
| // bfloat16_t is handled by x86_128-inl.h. |
| #if HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_API void Store(Vec512<float16_t> v, D /* tag */, |
| float16_t* HWY_RESTRICT aligned) { |
| _mm512_store_ph(aligned, v.raw); |
| } |
| #endif |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API void Store(Vec512<float> v, D /* tag */, float* HWY_RESTRICT aligned) { |
| _mm512_store_ps(aligned, v.raw); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API void Store(VFromD<D> v, D /* tag */, double* HWY_RESTRICT aligned) { |
| _mm512_store_pd(aligned, v.raw); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> |
| HWY_API void StoreU(VFromD<D> v, D /* tag */, TFromD<D>* HWY_RESTRICT p) { |
| _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); |
| } |
| // bfloat16_t is handled by x86_128-inl.h. |
| #if HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_API void StoreU(Vec512<float16_t> v, D /* tag */, |
| float16_t* HWY_RESTRICT p) { |
| _mm512_storeu_ph(p, v.raw); |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API void StoreU(Vec512<float> v, D /* tag */, float* HWY_RESTRICT p) { |
| _mm512_storeu_ps(p, v.raw); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API void StoreU(Vec512<double> v, D /* tag */, double* HWY_RESTRICT p) { |
| _mm512_storeu_pd(p, v.raw); |
| } |
| |
| // ------------------------------ BlendedStore |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */, |
| TFromD<D>* HWY_RESTRICT p) { |
| _mm512_mask_storeu_epi8(p, m.raw, v.raw); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d, |
| TFromD<D>* HWY_RESTRICT p) { |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| _mm512_mask_storeu_epi16(reinterpret_cast<uint16_t*>(p), m.raw, |
| BitCast(du, v).raw); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */, |
| TFromD<D>* HWY_RESTRICT p) { |
| _mm512_mask_storeu_epi32(p, m.raw, v.raw); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */, |
| TFromD<D>* HWY_RESTRICT p) { |
| _mm512_mask_storeu_epi64(p, m.raw, v.raw); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API void BlendedStore(Vec512<float> v, Mask512<float> m, D /* tag */, |
| float* HWY_RESTRICT p) { |
| _mm512_mask_storeu_ps(p, m.raw, v.raw); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API void BlendedStore(Vec512<double> v, Mask512<double> m, D /* tag */, |
| double* HWY_RESTRICT p) { |
| _mm512_mask_storeu_pd(p, m.raw, v.raw); |
| } |
| |
| // ------------------------------ Non-temporal stores |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)> |
| HWY_API void Stream(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT aligned) { |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), BitCast(du, v).raw); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API void Stream(VFromD<D> v, D /* tag */, float* HWY_RESTRICT aligned) { |
| _mm512_stream_ps(aligned, v.raw); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API void Stream(VFromD<D> v, D /* tag */, double* HWY_RESTRICT aligned) { |
| _mm512_stream_pd(aligned, v.raw); |
| } |
| |
| // ------------------------------ ScatterOffset |
| |
| // Work around warnings in the intrinsic definitions (passing -1 as a mask). |
| HWY_DIAGNOSTICS(push) |
| HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API void ScatterOffset(VFromD<D> v, D /* tag */, |
| TFromD<D>* HWY_RESTRICT base, |
| VFromD<RebindToSigned<D>> offset) { |
| _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API void ScatterOffset(VFromD<D> v, D /* tag */, |
| TFromD<D>* HWY_RESTRICT base, |
| VFromD<RebindToSigned<D>> offset) { |
| _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API void ScatterOffset(VFromD<D> v, D /* tag */, float* HWY_RESTRICT base, |
| Vec512<int32_t> offset) { |
| _mm512_i32scatter_ps(base, offset.raw, v.raw, 1); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API void ScatterOffset(VFromD<D> v, D /* tag */, double* HWY_RESTRICT base, |
| Vec512<int64_t> offset) { |
| _mm512_i64scatter_pd(base, offset.raw, v.raw, 1); |
| } |
| |
| // ------------------------------ ScatterIndex |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API void ScatterIndex(VFromD<D> v, D /* tag */, |
| TFromD<D>* HWY_RESTRICT base, |
| VFromD<RebindToSigned<D>> index) { |
| _mm512_i32scatter_epi32(base, index.raw, v.raw, 4); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API void ScatterIndex(VFromD<D> v, D /* tag */, |
| TFromD<D>* HWY_RESTRICT base, |
| VFromD<RebindToSigned<D>> index) { |
| _mm512_i64scatter_epi64(base, index.raw, v.raw, 8); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API void ScatterIndex(VFromD<D> v, D /* tag */, float* HWY_RESTRICT base, |
| Vec512<int32_t> index) { |
| _mm512_i32scatter_ps(base, index.raw, v.raw, 4); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API void ScatterIndex(VFromD<D> v, D /* tag */, double* HWY_RESTRICT base, |
| Vec512<int64_t> index) { |
| _mm512_i64scatter_pd(base, index.raw, v.raw, 8); |
| } |
| |
| // ------------------------------ MaskedScatterIndex |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */, |
| TFromD<D>* HWY_RESTRICT base, |
| VFromD<RebindToSigned<D>> index) { |
| _mm512_mask_i32scatter_epi32(base, m.raw, index.raw, v.raw, 4); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */, |
| TFromD<D>* HWY_RESTRICT base, |
| VFromD<RebindToSigned<D>> index) { |
| _mm512_mask_i64scatter_epi64(base, m.raw, index.raw, v.raw, 8); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */, |
| float* HWY_RESTRICT base, |
| Vec512<int32_t> index) { |
| _mm512_mask_i32scatter_ps(base, m.raw, index.raw, v.raw, 4); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */, |
| double* HWY_RESTRICT base, |
| Vec512<int64_t> index) { |
| _mm512_mask_i64scatter_pd(base, m.raw, index.raw, v.raw, 8); |
| } |
| |
| // ------------------------------ Gather |
| |
| namespace detail { |
| |
| template <int kScale, typename T, HWY_IF_UI32(T)> |
| HWY_INLINE Vec512<T> NativeGather512(const T* HWY_RESTRICT base, |
| Vec512<int32_t> indices) { |
| return Vec512<T>{_mm512_i32gather_epi32(indices.raw, base, kScale)}; |
| } |
| |
| template <int kScale, typename T, HWY_IF_UI64(T)> |
| HWY_INLINE Vec512<T> NativeGather512(const T* HWY_RESTRICT base, |
| Vec512<int64_t> indices) { |
| return Vec512<T>{_mm512_i64gather_epi64(indices.raw, base, kScale)}; |
| } |
| |
| template <int kScale> |
| HWY_INLINE Vec512<float> NativeGather512(const float* HWY_RESTRICT base, |
| Vec512<int32_t> indices) { |
| return Vec512<float>{_mm512_i32gather_ps(indices.raw, base, kScale)}; |
| } |
| |
| template <int kScale> |
| HWY_INLINE Vec512<double> NativeGather512(const double* HWY_RESTRICT base, |
| Vec512<int64_t> indices) { |
| return Vec512<double>{_mm512_i64gather_pd(indices.raw, base, kScale)}; |
| } |
| |
| template <int kScale, typename T, HWY_IF_UI32(T)> |
| HWY_INLINE Vec512<T> NativeMaskedGatherOr512(Vec512<T> no, Mask512<T> m, |
| const T* HWY_RESTRICT base, |
| Vec512<int32_t> indices) { |
| return Vec512<T>{ |
| _mm512_mask_i32gather_epi32(no.raw, m.raw, indices.raw, base, kScale)}; |
| } |
| |
| template <int kScale, typename T, HWY_IF_UI64(T)> |
| HWY_INLINE Vec512<T> NativeMaskedGatherOr512(Vec512<T> no, Mask512<T> m, |
| const T* HWY_RESTRICT base, |
| Vec512<int64_t> indices) { |
| return Vec512<T>{ |
| _mm512_mask_i64gather_epi64(no.raw, m.raw, indices.raw, base, kScale)}; |
| } |
| |
| template <int kScale> |
| HWY_INLINE Vec512<float> NativeMaskedGatherOr512(Vec512<float> no, |
| Mask512<float> m, |
| const float* HWY_RESTRICT base, |
| Vec512<int32_t> indices) { |
| return Vec512<float>{ |
| _mm512_mask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; |
| } |
| |
| template <int kScale> |
| HWY_INLINE Vec512<double> NativeMaskedGatherOr512( |
| Vec512<double> no, Mask512<double> m, const double* HWY_RESTRICT base, |
| Vec512<int64_t> indices) { |
| return Vec512<double>{ |
| _mm512_mask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; |
| } |
| } // namespace detail |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> GatherOffset(D /*d*/, const TFromD<D>* HWY_RESTRICT base, |
| VFromD<RebindToSigned<D>> offsets) { |
| return detail::NativeGather512<1>(base, offsets); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> GatherIndex(D /*d*/, const TFromD<D>* HWY_RESTRICT base, |
| VFromD<RebindToSigned<D>> indices) { |
| return detail::NativeGather512<sizeof(TFromD<D>)>(base, indices); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> MaskedGatherIndexOr(VFromD<D> no, MFromD<D> m, D /*d*/, |
| const TFromD<D>* HWY_RESTRICT base, |
| VFromD<RebindToSigned<D>> indices) { |
| return detail::NativeMaskedGatherOr512<sizeof(TFromD<D>)>(no, m, base, |
| indices); |
| } |
| |
| HWY_DIAGNOSTICS(pop) |
| |
| // ================================================== SWIZZLE |
| |
| // ------------------------------ LowerHalf |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> |
| HWY_API VFromD<D> LowerHalf(D /* tag */, VFromD<Twice<D>> v) { |
| return VFromD<D>{_mm512_castsi512_si256(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_BF16_D(D)> |
| HWY_API VFromD<D> LowerHalf(D /* tag */, Vec512<bfloat16_t> v) { |
| return VFromD<D>{_mm512_castsi512_si256(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)> |
| HWY_API VFromD<D> LowerHalf(D /* tag */, Vec512<float16_t> v) { |
| #if HWY_HAVE_FLOAT16 |
| return VFromD<D>{_mm512_castph512_ph256(v.raw)}; |
| #else |
| return VFromD<D>{_mm512_castsi512_si256(v.raw)}; |
| #endif // HWY_HAVE_FLOAT16 |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> LowerHalf(D /* tag */, Vec512<float> v) { |
| return VFromD<D>{_mm512_castps512_ps256(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> LowerHalf(D /* tag */, Vec512<double> v) { |
| return VFromD<D>{_mm512_castpd512_pd256(v.raw)}; |
| } |
| |
| template <typename T> |
| HWY_API Vec256<T> LowerHalf(Vec512<T> v) { |
| const Half<DFromV<decltype(v)>> dh; |
| return LowerHalf(dh, v); |
| } |
| |
| // ------------------------------ UpperHalf |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT3264_D(D)> |
| HWY_API VFromD<D> UpperHalf(D d, VFromD<Twice<D>> v) { |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| const Twice<decltype(du)> dut; |
| return BitCast(d, VFromD<decltype(du)>{ |
| _mm512_extracti32x8_epi32(BitCast(dut, v).raw, 1)}); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> UpperHalf(D /* tag */, VFromD<Twice<D>> v) { |
| return VFromD<D>{_mm512_extractf32x8_ps(v.raw, 1)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> UpperHalf(D /* tag */, VFromD<Twice<D>> v) { |
| return VFromD<D>{_mm512_extractf64x4_pd(v.raw, 1)}; |
| } |
| |
| // ------------------------------ ExtractLane (Store) |
| template <typename T> |
| HWY_API T ExtractLane(const Vec512<T> v, size_t i) { |
| const DFromV<decltype(v)> d; |
| HWY_DASSERT(i < Lanes(d)); |
| |
| #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang |
| constexpr size_t kLanesPerBlock = 16 / sizeof(T); |
| if (__builtin_constant_p(i < kLanesPerBlock) && (i < kLanesPerBlock)) { |
| return ExtractLane(ResizeBitCast(Full128<T>(), v), i); |
| } |
| #endif |
| |
| alignas(64) T lanes[MaxLanes(d)]; |
| Store(v, d, lanes); |
| return lanes[i]; |
| } |
| |
| // ------------------------------ ExtractBlock |
| template <int kBlockIdx, class T, hwy::EnableIf<(kBlockIdx <= 1)>* = nullptr> |
| HWY_API Vec128<T> ExtractBlock(Vec512<T> v) { |
| const DFromV<decltype(v)> d; |
| const Half<decltype(d)> dh; |
| return ExtractBlock<kBlockIdx>(LowerHalf(dh, v)); |
| } |
| |
| template <int kBlockIdx, class T, hwy::EnableIf<(kBlockIdx > 1)>* = nullptr> |
| HWY_API Vec128<T> ExtractBlock(Vec512<T> v) { |
| static_assert(kBlockIdx <= 3, "Invalid block index"); |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast(Full128<T>(), |
| Vec128<MakeUnsigned<T>>{ |
| _mm512_extracti32x4_epi32(BitCast(du, v).raw, kBlockIdx)}); |
| } |
| |
| template <int kBlockIdx, hwy::EnableIf<(kBlockIdx > 1)>* = nullptr> |
| HWY_API Vec128<float> ExtractBlock(Vec512<float> v) { |
| static_assert(kBlockIdx <= 3, "Invalid block index"); |
| return Vec128<float>{_mm512_extractf32x4_ps(v.raw, kBlockIdx)}; |
| } |
| |
| template <int kBlockIdx, hwy::EnableIf<(kBlockIdx > 1)>* = nullptr> |
| HWY_API Vec128<double> ExtractBlock(Vec512<double> v) { |
| static_assert(kBlockIdx <= 3, "Invalid block index"); |
| return Vec128<double>{_mm512_extractf64x2_pd(v.raw, kBlockIdx)}; |
| } |
| |
| // ------------------------------ InsertLane (Store) |
| template <typename T> |
| HWY_API Vec512<T> InsertLane(const Vec512<T> v, size_t i, T t) { |
| return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); |
| } |
| |
| // ------------------------------ InsertBlock |
| namespace detail { |
| |
| template <typename T> |
| HWY_INLINE Vec512<T> InsertBlock(hwy::SizeTag<0> /* blk_idx_tag */, Vec512<T> v, |
| Vec128<T> blk_to_insert) { |
| const DFromV<decltype(v)> d; |
| const auto insert_mask = FirstN(d, 16 / sizeof(T)); |
| return IfThenElse(insert_mask, ResizeBitCast(d, blk_to_insert), v); |
| } |
| |
| template <size_t kBlockIdx, typename T> |
| HWY_INLINE Vec512<T> InsertBlock(hwy::SizeTag<kBlockIdx> /* blk_idx_tag */, |
| Vec512<T> v, Vec128<T> blk_to_insert) { |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| const Full128<MakeUnsigned<T>> du_blk_to_insert; |
| return BitCast( |
| d, VFromD<decltype(du)>{_mm512_inserti32x4( |
| BitCast(du, v).raw, BitCast(du_blk_to_insert, blk_to_insert).raw, |
| static_cast<int>(kBlockIdx & 3))}); |
| } |
| |
| template <size_t kBlockIdx, hwy::EnableIf<kBlockIdx != 0>* = nullptr> |
| HWY_INLINE Vec512<float> InsertBlock(hwy::SizeTag<kBlockIdx> /* blk_idx_tag */, |
| Vec512<float> v, |
| Vec128<float> blk_to_insert) { |
| return Vec512<float>{_mm512_insertf32x4(v.raw, blk_to_insert.raw, |
| static_cast<int>(kBlockIdx & 3))}; |
| } |
| |
| template <size_t kBlockIdx, hwy::EnableIf<kBlockIdx != 0>* = nullptr> |
| HWY_INLINE Vec512<double> InsertBlock(hwy::SizeTag<kBlockIdx> /* blk_idx_tag */, |
| Vec512<double> v, |
| Vec128<double> blk_to_insert) { |
| return Vec512<double>{_mm512_insertf64x2(v.raw, blk_to_insert.raw, |
| static_cast<int>(kBlockIdx & 3))}; |
| } |
| |
| } // namespace detail |
| |
| template <int kBlockIdx, class T> |
| HWY_API Vec512<T> InsertBlock(Vec512<T> v, Vec128<T> blk_to_insert) { |
| static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); |
| return detail::InsertBlock(hwy::SizeTag<static_cast<size_t>(kBlockIdx)>(), v, |
| blk_to_insert); |
| } |
| |
| // ------------------------------ GetLane (LowerHalf) |
| template <typename T> |
| HWY_API T GetLane(const Vec512<T> v) { |
| return GetLane(LowerHalf(v)); |
| } |
| |
| // ------------------------------ ZeroExtendVector |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> |
| HWY_API VFromD<D> ZeroExtendVector(D d, VFromD<Half<D>> lo) { |
| #if HWY_HAVE_ZEXT // See definition/comment in x86_256-inl.h. |
| (void)d; |
| return VFromD<D>{_mm512_zextsi256_si512(lo.raw)}; |
| #else |
| return VFromD<D>{_mm512_inserti32x8(Zero(d).raw, lo.raw, 0)}; |
| #endif |
| } |
| #if HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_API VFromD<D> ZeroExtendVector(D d, VFromD<Half<D>> lo) { |
| #if HWY_HAVE_ZEXT |
| (void)d; |
| return VFromD<D>{_mm512_zextph256_ph512(lo.raw)}; |
| #else |
| const RebindToUnsigned<D> du; |
| return BitCast(d, ZeroExtendVector(du, BitCast(du, lo))); |
| #endif |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> ZeroExtendVector(D d, VFromD<Half<D>> lo) { |
| #if HWY_HAVE_ZEXT |
| (void)d; |
| return VFromD<D>{_mm512_zextps256_ps512(lo.raw)}; |
| #else |
| return VFromD<D>{_mm512_insertf32x8(Zero(d).raw, lo.raw, 0)}; |
| #endif |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> ZeroExtendVector(D d, VFromD<Half<D>> lo) { |
| #if HWY_HAVE_ZEXT |
| (void)d; |
| return VFromD<D>{_mm512_zextpd256_pd512(lo.raw)}; |
| #else |
| return VFromD<D>{_mm512_insertf64x4(Zero(d).raw, lo.raw, 0)}; |
| #endif |
| } |
| |
| // ------------------------------ ZeroExtendResizeBitCast |
| |
| namespace detail { |
| |
| template <class DTo, class DFrom, HWY_IF_NOT_FLOAT3264_D(DTo)> |
| HWY_INLINE VFromD<DTo> ZeroExtendResizeBitCast( |
| hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, |
| DTo d_to, DFrom d_from, VFromD<DFrom> v) { |
| const Repartition<uint8_t, decltype(d_from)> du8_from; |
| const auto vu8 = BitCast(du8_from, v); |
| const RebindToUnsigned<decltype(d_to)> du_to; |
| #if HWY_HAVE_ZEXT |
| return BitCast(d_to, |
| VFromD<decltype(du_to)>{_mm512_zextsi128_si512(vu8.raw)}); |
| #else |
| return BitCast(d_to, VFromD<decltype(du_to)>{ |
| _mm512_inserti32x4(Zero(du_to).raw, vu8.raw, 0)}); |
| #endif |
| } |
| |
| template <class DTo, class DFrom, HWY_IF_F32_D(DTo)> |
| HWY_INLINE VFromD<DTo> ZeroExtendResizeBitCast( |
| hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, |
| DTo d_to, DFrom d_from, VFromD<DFrom> v) { |
| const Repartition<float, decltype(d_from)> df32_from; |
| const auto vf32 = BitCast(df32_from, v); |
| #if HWY_HAVE_ZEXT |
| (void)d_to; |
| return Vec512<float>{_mm512_zextps128_ps512(vf32.raw)}; |
| #else |
| return Vec512<float>{_mm512_insertf32x4(Zero(d_to).raw, vf32.raw, 0)}; |
| #endif |
| } |
| |
| template <class DTo, class DFrom, HWY_IF_F64_D(DTo)> |
| HWY_INLINE Vec512<double> ZeroExtendResizeBitCast( |
| hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, |
| DTo d_to, DFrom d_from, VFromD<DFrom> v) { |
| const Repartition<double, decltype(d_from)> df64_from; |
| const auto vf64 = BitCast(df64_from, v); |
| #if HWY_HAVE_ZEXT |
| (void)d_to; |
| return Vec512<double>{_mm512_zextpd128_pd512(vf64.raw)}; |
| #else |
| return Vec512<double>{_mm512_insertf64x2(Zero(d_to).raw, vf64.raw, 0)}; |
| #endif |
| } |
| |
| template <class DTo, class DFrom> |
| HWY_INLINE VFromD<DTo> ZeroExtendResizeBitCast( |
| hwy::SizeTag<8> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, |
| DTo d_to, DFrom d_from, VFromD<DFrom> v) { |
| const Twice<decltype(d_from)> dt_from; |
| return ZeroExtendResizeBitCast(hwy::SizeTag<16>(), hwy::SizeTag<64>(), d_to, |
| dt_from, ZeroExtendVector(dt_from, v)); |
| } |
| |
| } // namespace detail |
| |
| // ------------------------------ Combine |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)> |
| HWY_API VFromD<D> Combine(D d, VFromD<Half<D>> hi, VFromD<Half<D>> lo) { |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| const Half<decltype(du)> duh; |
| const __m512i lo512 = ZeroExtendVector(du, BitCast(duh, lo)).raw; |
| return BitCast(d, VFromD<decltype(du)>{ |
| _mm512_inserti32x8(lo512, BitCast(duh, hi).raw, 1)}); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> Combine(D d, VFromD<Half<D>> hi, VFromD<Half<D>> lo) { |
| return VFromD<D>{_mm512_insertf32x8(ZeroExtendVector(d, lo).raw, hi.raw, 1)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> Combine(D d, VFromD<Half<D>> hi, VFromD<Half<D>> lo) { |
| return VFromD<D>{_mm512_insertf64x4(ZeroExtendVector(d, lo).raw, hi.raw, 1)}; |
| } |
| |
| // ------------------------------ ShiftLeftBytes |
| template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> ShiftLeftBytes(D /* tag */, const VFromD<D> v) { |
| static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); |
| return VFromD<D>{_mm512_bslli_epi128(v.raw, kBytes)}; |
| } |
| |
| // ------------------------------ ShiftRightBytes |
| template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> ShiftRightBytes(D /* tag */, const VFromD<D> v) { |
| static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); |
| return VFromD<D>{_mm512_bsrli_epi128(v.raw, kBytes)}; |
| } |
| |
| // ------------------------------ CombineShiftRightBytes |
| |
| template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> CombineShiftRightBytes(D d, VFromD<D> hi, VFromD<D> lo) { |
| const Repartition<uint8_t, decltype(d)> d8; |
| return BitCast(d, Vec512<uint8_t>{_mm512_alignr_epi8( |
| BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); |
| } |
| |
| // ------------------------------ Broadcast/splat any lane |
| |
| template <int kLane, typename T, HWY_IF_T_SIZE(T, 2)> |
| HWY_API Vec512<T> Broadcast(const Vec512<T> v) { |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| using VU = VFromD<decltype(du)>; |
| const VU vu = BitCast(du, v); // for float16_t |
| static_assert(0 <= kLane && kLane < 8, "Invalid lane"); |
| if (kLane < 4) { |
| const __m512i lo = _mm512_shufflelo_epi16(vu.raw, (0x55 * kLane) & 0xFF); |
| return BitCast(d, VU{_mm512_unpacklo_epi64(lo, lo)}); |
| } else { |
| const __m512i hi = |
| _mm512_shufflehi_epi16(vu.raw, (0x55 * (kLane - 4)) & 0xFF); |
| return BitCast(d, VU{_mm512_unpackhi_epi64(hi, hi)}); |
| } |
| } |
| |
| template <int kLane, typename T, HWY_IF_UI32(T)> |
| HWY_API Vec512<T> Broadcast(const Vec512<T> v) { |
| static_assert(0 <= kLane && kLane < 4, "Invalid lane"); |
| constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); |
| return Vec512<T>{_mm512_shuffle_epi32(v.raw, perm)}; |
| } |
| |
| template <int kLane, typename T, HWY_IF_UI64(T)> |
| HWY_API Vec512<T> Broadcast(const Vec512<T> v) { |
| static_assert(0 <= kLane && kLane < 2, "Invalid lane"); |
| constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; |
| return Vec512<T>{_mm512_shuffle_epi32(v.raw, perm)}; |
| } |
| |
| template <int kLane> |
| HWY_API Vec512<float> Broadcast(const Vec512<float> v) { |
| static_assert(0 <= kLane && kLane < 4, "Invalid lane"); |
| constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); |
| return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, perm)}; |
| } |
| |
| template <int kLane> |
| HWY_API Vec512<double> Broadcast(const Vec512<double> v) { |
| static_assert(0 <= kLane && kLane < 2, "Invalid lane"); |
| constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane); |
| return Vec512<double>{_mm512_shuffle_pd(v.raw, v.raw, perm)}; |
| } |
| |
| // ------------------------------ BroadcastBlock |
| template <int kBlockIdx, class T> |
| HWY_API Vec512<T> BroadcastBlock(Vec512<T> v) { |
| static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast( |
| d, VFromD<decltype(du)>{_mm512_shuffle_i32x4( |
| BitCast(du, v).raw, BitCast(du, v).raw, 0x55 * kBlockIdx)}); |
| } |
| |
| template <int kBlockIdx> |
| HWY_API Vec512<float> BroadcastBlock(Vec512<float> v) { |
| static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); |
| return Vec512<float>{_mm512_shuffle_f32x4(v.raw, v.raw, 0x55 * kBlockIdx)}; |
| } |
| |
| template <int kBlockIdx> |
| HWY_API Vec512<double> BroadcastBlock(Vec512<double> v) { |
| static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); |
| return Vec512<double>{_mm512_shuffle_f64x2(v.raw, v.raw, 0x55 * kBlockIdx)}; |
| } |
| |
| // ------------------------------ BroadcastLane |
| |
| namespace detail { |
| |
| template <class T, HWY_IF_T_SIZE(T, 1)> |
| HWY_INLINE Vec512<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, |
| Vec512<T> v) { |
| return Vec512<T>{_mm512_broadcastb_epi8(ResizeBitCast(Full128<T>(), v).raw)}; |
| } |
| |
| template <class T, HWY_IF_T_SIZE(T, 2)> |
| HWY_INLINE Vec512<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, |
| Vec512<T> v) { |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast(d, VFromD<decltype(du)>{_mm512_broadcastw_epi16( |
| ResizeBitCast(Full128<uint16_t>(), v).raw)}); |
| } |
| |
| template <class T, HWY_IF_UI32(T)> |
| HWY_INLINE Vec512<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, |
| Vec512<T> v) { |
| return Vec512<T>{_mm512_broadcastd_epi32(ResizeBitCast(Full128<T>(), v).raw)}; |
| } |
| |
| template <class T, HWY_IF_UI64(T)> |
| HWY_INLINE Vec512<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, |
| Vec512<T> v) { |
| return Vec512<T>{_mm512_broadcastq_epi64(ResizeBitCast(Full128<T>(), v).raw)}; |
| } |
| |
| HWY_INLINE Vec512<float> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, |
| Vec512<float> v) { |
| return Vec512<float>{ |
| _mm512_broadcastss_ps(ResizeBitCast(Full128<float>(), v).raw)}; |
| } |
| |
| HWY_INLINE Vec512<double> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, |
| Vec512<double> v) { |
| return Vec512<double>{ |
| _mm512_broadcastsd_pd(ResizeBitCast(Full128<double>(), v).raw)}; |
| } |
| |
| template <size_t kLaneIdx, class T, hwy::EnableIf<kLaneIdx != 0>* = nullptr> |
| HWY_INLINE Vec512<T> BroadcastLane(hwy::SizeTag<kLaneIdx> /* lane_idx_tag */, |
| Vec512<T> v) { |
| constexpr size_t kLanesPerBlock = 16 / sizeof(T); |
| constexpr int kBlockIdx = static_cast<int>(kLaneIdx / kLanesPerBlock); |
| constexpr int kLaneInBlkIdx = |
| static_cast<int>(kLaneIdx) & (kLanesPerBlock - 1); |
| return Broadcast<kLaneInBlkIdx>(BroadcastBlock<kBlockIdx>(v)); |
| } |
| |
| } // namespace detail |
| |
| template <int kLaneIdx, class T> |
| HWY_API Vec512<T> BroadcastLane(Vec512<T> v) { |
| static_assert(0 <= kLaneIdx, "Invalid lane"); |
| return detail::BroadcastLane(hwy::SizeTag<static_cast<size_t>(kLaneIdx)>(), |
| v); |
| } |
| |
| // ------------------------------ Hard-coded shuffles |
| |
| // Notation: let Vec512<int32_t> have lanes 7,6,5,4,3,2,1,0 (0 is |
| // least-significant). Shuffle0321 rotates four-lane blocks one lane to the |
| // right (the previous least-significant lane is now most-significant => |
| // 47650321). These could also be implemented via CombineShiftRightBytes but |
| // the shuffle_abcd notation is more convenient. |
| |
// Swap 32-bit halves within each 64-bit lane.
| template <typename T, HWY_IF_UI32(T)> |
| HWY_API Vec512<T> Shuffle2301(const Vec512<T> v) { |
| return Vec512<T>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; |
| } |
| HWY_API Vec512<float> Shuffle2301(const Vec512<float> v) { |
| return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)}; |
| } |
| |
| namespace detail { |
| |
| template <typename T, HWY_IF_T_SIZE(T, 4)> |
| HWY_API Vec512<T> ShuffleTwo2301(const Vec512<T> a, const Vec512<T> b) { |
| const DFromV<decltype(a)> d; |
| const RebindToFloat<decltype(d)> df; |
| return BitCast( |
| d, Vec512<float>{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, |
| _MM_PERM_CDAB)}); |
| } |
| template <typename T, HWY_IF_T_SIZE(T, 4)> |
| HWY_API Vec512<T> ShuffleTwo1230(const Vec512<T> a, const Vec512<T> b) { |
| const DFromV<decltype(a)> d; |
| const RebindToFloat<decltype(d)> df; |
| return BitCast( |
| d, Vec512<float>{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, |
| _MM_PERM_BCDA)}); |
| } |
| template <typename T, HWY_IF_T_SIZE(T, 4)> |
| HWY_API Vec512<T> ShuffleTwo3012(const Vec512<T> a, const Vec512<T> b) { |
| const DFromV<decltype(a)> d; |
| const RebindToFloat<decltype(d)> df; |
| return BitCast( |
| d, Vec512<float>{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, |
| _MM_PERM_DABC)}); |
| } |
| |
| } // namespace detail |
| |
| // Swap 64-bit halves |
| HWY_API Vec512<uint32_t> Shuffle1032(const Vec512<uint32_t> v) { |
| return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; |
| } |
| HWY_API Vec512<int32_t> Shuffle1032(const Vec512<int32_t> v) { |
| return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; |
| } |
| HWY_API Vec512<float> Shuffle1032(const Vec512<float> v) { |
| // Shorter encoding than _mm512_permute_ps. |
| return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)}; |
| } |
| HWY_API Vec512<uint64_t> Shuffle01(const Vec512<uint64_t> v) { |
| return Vec512<uint64_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; |
| } |
| HWY_API Vec512<int64_t> Shuffle01(const Vec512<int64_t> v) { |
| return Vec512<int64_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; |
| } |
| HWY_API Vec512<double> Shuffle01(const Vec512<double> v) { |
| // Shorter encoding than _mm512_permute_pd. |
| return Vec512<double>{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)}; |
| } |
| |
| // Rotate right 32 bits |
| HWY_API Vec512<uint32_t> Shuffle0321(const Vec512<uint32_t> v) { |
| return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; |
| } |
| HWY_API Vec512<int32_t> Shuffle0321(const Vec512<int32_t> v) { |
| return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; |
| } |
| HWY_API Vec512<float> Shuffle0321(const Vec512<float> v) { |
| return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)}; |
| } |
| // Rotate left 32 bits |
| HWY_API Vec512<uint32_t> Shuffle2103(const Vec512<uint32_t> v) { |
| return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; |
| } |
| HWY_API Vec512<int32_t> Shuffle2103(const Vec512<int32_t> v) { |
| return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; |
| } |
| HWY_API Vec512<float> Shuffle2103(const Vec512<float> v) { |
| return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)}; |
| } |
| |
| // Reverse |
| HWY_API Vec512<uint32_t> Shuffle0123(const Vec512<uint32_t> v) { |
| return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; |
| } |
| HWY_API Vec512<int32_t> Shuffle0123(const Vec512<int32_t> v) { |
| return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; |
| } |
| HWY_API Vec512<float> Shuffle0123(const Vec512<float> v) { |
| return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)}; |
| } |
| |
| // ------------------------------ TableLookupLanes |
| |
| // Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. |
| template <typename T> |
| struct Indices512 { |
| __m512i raw; |
| }; |
| |
| template <class D, typename T = TFromD<D>, typename TI> |
| HWY_API Indices512<T> IndicesFromVec(D /* tag */, Vec512<TI> vec) { |
| static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); |
| #if HWY_IS_DEBUG_BUILD |
| const DFromV<decltype(vec)> di; |
| const RebindToUnsigned<decltype(di)> du; |
| using TU = MakeUnsigned<T>; |
| const auto vec_u = BitCast(du, vec); |
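  // Indices may be in [0, 2 * Lanes(d)): the same Indices512 is also accepted
  // by TwoTablesLookupLanes, hence the bound of 128 / sizeof(T) below.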
| HWY_DASSERT( |
| AllTrue(du, Lt(vec_u, Set(du, static_cast<TU>(128 / sizeof(T)))))); |
| #endif |
| return Indices512<T>{vec.raw}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), typename TI> |
| HWY_API Indices512<TFromD<D>> SetTableIndices(D d, const TI* idx) { |
| const Rebind<TI, decltype(d)> di; |
| return IndicesFromVec(d, LoadU(di, idx)); |
| } |
| |
| template <typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_API Vec512<T> TableLookupLanes(Vec512<T> v, Indices512<T> idx) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return Vec512<T>{_mm512_permutexvar_epi8(idx.raw, v.raw)}; |
| #else |
| const DFromV<decltype(v)> d; |
| const Repartition<uint16_t, decltype(d)> du16; |
| const Vec512<T> idx_vec{idx.raw}; |
| |
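  // Byte indices are in [0, 64). Bit 4 selects blocks b/d and bit 5 selects
  // blocks c/d; shift each into the sign bit so MaskFromVec can test it.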
| const auto bd_sel_mask = |
| MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, idx_vec)))); |
| const auto cd_sel_mask = |
| MaskFromVec(BitCast(d, ShiftLeft<2>(BitCast(du16, idx_vec)))); |
| |
| const Vec512<T> v_a{_mm512_shuffle_i32x4(v.raw, v.raw, 0x00)}; |
| const Vec512<T> v_b{_mm512_shuffle_i32x4(v.raw, v.raw, 0x55)}; |
| const Vec512<T> v_c{_mm512_shuffle_i32x4(v.raw, v.raw, 0xAA)}; |
| const Vec512<T> v_d{_mm512_shuffle_i32x4(v.raw, v.raw, 0xFF)}; |
| |
| const auto shuf_a = TableLookupBytes(v_a, idx_vec); |
| const auto shuf_c = TableLookupBytes(v_c, idx_vec); |
| const Vec512<T> shuf_ab{_mm512_mask_shuffle_epi8(shuf_a.raw, bd_sel_mask.raw, |
| v_b.raw, idx_vec.raw)}; |
| const Vec512<T> shuf_cd{_mm512_mask_shuffle_epi8(shuf_c.raw, bd_sel_mask.raw, |
| v_d.raw, idx_vec.raw)}; |
| return IfThenElse(cd_sel_mask, shuf_cd, shuf_ab); |
| #endif |
| } |
| |
| template <typename T, HWY_IF_T_SIZE(T, 2), HWY_IF_NOT_SPECIAL_FLOAT(T)> |
| HWY_API Vec512<T> TableLookupLanes(Vec512<T> v, Indices512<T> idx) { |
| return Vec512<T>{_mm512_permutexvar_epi16(idx.raw, v.raw)}; |
| } |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> TableLookupLanes(Vec512<float16_t> v, |
| Indices512<float16_t> idx) { |
| return Vec512<float16_t>{_mm512_permutexvar_ph(idx.raw, v.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| template <typename T, HWY_IF_T_SIZE(T, 4)> |
| HWY_API Vec512<T> TableLookupLanes(Vec512<T> v, Indices512<T> idx) { |
| return Vec512<T>{_mm512_permutexvar_epi32(idx.raw, v.raw)}; |
| } |
| |
| template <typename T, HWY_IF_T_SIZE(T, 8)> |
| HWY_API Vec512<T> TableLookupLanes(Vec512<T> v, Indices512<T> idx) { |
| return Vec512<T>{_mm512_permutexvar_epi64(idx.raw, v.raw)}; |
| } |
| |
| HWY_API Vec512<float> TableLookupLanes(Vec512<float> v, Indices512<float> idx) { |
| return Vec512<float>{_mm512_permutexvar_ps(idx.raw, v.raw)}; |
| } |
| |
| HWY_API Vec512<double> TableLookupLanes(Vec512<double> v, |
| Indices512<double> idx) { |
| return Vec512<double>{_mm512_permutexvar_pd(idx.raw, v.raw)}; |
| } |
| |
| template <typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_API Vec512<T> TwoTablesLookupLanes(Vec512<T> a, Vec512<T> b, |
| Indices512<T> idx) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return Vec512<T>{_mm512_permutex2var_epi8(a.raw, idx.raw, b.raw)}; |
| #else |
| const DFromV<decltype(a)> d; |
| const auto b_sel_mask = |
| MaskFromVec(BitCast(d, ShiftLeft<1>(Vec512<uint16_t>{idx.raw}))); |
| return IfThenElse(b_sel_mask, TableLookupLanes(b, idx), |
| TableLookupLanes(a, idx)); |
| #endif |
| } |
| |
| template <typename T, HWY_IF_T_SIZE(T, 2)> |
| HWY_API Vec512<T> TwoTablesLookupLanes(Vec512<T> a, Vec512<T> b, |
| Indices512<T> idx) { |
| return Vec512<T>{_mm512_permutex2var_epi16(a.raw, idx.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_UI32(T)> |
| HWY_API Vec512<T> TwoTablesLookupLanes(Vec512<T> a, Vec512<T> b, |
| Indices512<T> idx) { |
| return Vec512<T>{_mm512_permutex2var_epi32(a.raw, idx.raw, b.raw)}; |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float16_t> TwoTablesLookupLanes(Vec512<float16_t> a, |
| Vec512<float16_t> b, |
| Indices512<float16_t> idx) { |
| return Vec512<float16_t>{_mm512_permutex2var_ph(a.raw, idx.raw, b.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| HWY_API Vec512<float> TwoTablesLookupLanes(Vec512<float> a, Vec512<float> b, |
| Indices512<float> idx) { |
| return Vec512<float>{_mm512_permutex2var_ps(a.raw, idx.raw, b.raw)}; |
| } |
| |
| template <typename T, HWY_IF_UI64(T)> |
| HWY_API Vec512<T> TwoTablesLookupLanes(Vec512<T> a, Vec512<T> b, |
| Indices512<T> idx) { |
| return Vec512<T>{_mm512_permutex2var_epi64(a.raw, idx.raw, b.raw)}; |
| } |
| |
| HWY_API Vec512<double> TwoTablesLookupLanes(Vec512<double> a, Vec512<double> b, |
| Indices512<double> idx) { |
| return Vec512<double>{_mm512_permutex2var_pd(a.raw, idx.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Reverse |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| const RebindToSigned<decltype(d)> di; |
| alignas(64) static constexpr int8_t kReverse[64] = { |
| 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, |
| 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, |
| 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, |
| 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; |
| const Vec512<int8_t> idx = Load(di, kReverse); |
| return BitCast( |
| d, Vec512<int8_t>{_mm512_permutexvar_epi8(idx.raw, BitCast(di, v).raw)}); |
| #else |
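  // Fallback without the AVX3_DL byte permute: swap the two bytes within each
  // u16 via RotateRight<8>, then reverse the order of the u16 lanes.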
| const RepartitionToWide<decltype(d)> d16; |
| return BitCast(d, Reverse(d16, RotateRight<8>(BitCast(d16, v)))); |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) { |
| const RebindToSigned<decltype(d)> di; |
| alignas(64) static constexpr int16_t kReverse[32] = { |
| 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, |
| 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; |
| const Vec512<int16_t> idx = Load(di, kReverse); |
| return BitCast(d, Vec512<int16_t>{ |
| _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)> |
| HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) { |
| alignas(64) static constexpr int32_t kReverse[16] = { |
| 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; |
| return TableLookupLanes(v, SetTableIndices(d, kReverse)); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)> |
| HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) { |
| alignas(64) static constexpr int64_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; |
| return TableLookupLanes(v, SetTableIndices(d, kReverse)); |
| } |
| |
| // ------------------------------ Reverse2 (in x86_128) |
| |
| // ------------------------------ Reverse4 |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { |
| const RebindToSigned<decltype(d)> di; |
| alignas(64) static constexpr int16_t kReverse4[32] = { |
| 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, |
| 19, 18, 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28}; |
| const Vec512<int16_t> idx = Load(di, kReverse4); |
| return BitCast(d, Vec512<int16_t>{ |
| _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); |
| } |
| |
| // 32 bit Reverse4 defined in x86_128. |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API VFromD<D> Reverse4(D /* tag */, const VFromD<D> v) { |
| return VFromD<D>{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> Reverse4(D /* tag */, VFromD<D> v) { |
| return VFromD<D>{_mm512_permutex_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; |
| } |
| |
| // ------------------------------ Reverse8 |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { |
| const RebindToSigned<decltype(d)> di; |
| alignas(64) static constexpr int16_t kReverse8[32] = { |
| 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, |
| 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}; |
| const Vec512<int16_t> idx = Load(di, kReverse8); |
| return BitCast(d, Vec512<int16_t>{ |
| _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)> |
| HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { |
| const RebindToSigned<decltype(d)> di; |
| alignas(64) static constexpr int32_t kReverse8[16] = { |
| 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8}; |
| const Vec512<int32_t> idx = Load(di, kReverse8); |
| return BitCast(d, Vec512<int32_t>{ |
| _mm512_permutexvar_epi32(idx.raw, BitCast(di, v).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)> |
| HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { |
| return Reverse(d, v); |
| } |
| |
| // ------------------------------ ReverseBits (GaloisAffine) |
| |
| #if HWY_TARGET <= HWY_AVX3_DL |
| |
| #ifdef HWY_NATIVE_REVERSE_BITS_UI8 |
| #undef HWY_NATIVE_REVERSE_BITS_UI8 |
| #else |
| #define HWY_NATIVE_REVERSE_BITS_UI8 |
| #endif |
| |
| // Generic for all vector lengths. Must be defined after all GaloisAffine. |
| template <class V, HWY_IF_T_SIZE_V(V, 1)> |
| HWY_API V ReverseBits(V v) { |
| const Repartition<uint64_t, DFromV<V>> du64; |
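  // The affine matrix 0x8040201008040201 reverses the bit order within each
  // byte (maps bit i to bit 7 - i).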
| return detail::GaloisAffine(v, Set(du64, 0x8040201008040201u)); |
| } |
| |
| #endif // HWY_TARGET <= HWY_AVX3_DL |
| |
| // ------------------------------ InterleaveLower |
| |
| template <typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_API Vec512<T> InterleaveLower(Vec512<T> a, Vec512<T> b) { |
| return Vec512<T>{_mm512_unpacklo_epi8(a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_T_SIZE(T, 2)> |
| HWY_API Vec512<T> InterleaveLower(Vec512<T> a, Vec512<T> b) { |
| const DFromV<decltype(a)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| using VU = VFromD<decltype(du)>; // for float16_t |
| return BitCast( |
| d, VU{_mm512_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); |
| } |
| template <typename T, HWY_IF_T_SIZE(T, 4)> |
| HWY_API Vec512<T> InterleaveLower(Vec512<T> a, Vec512<T> b) { |
| return Vec512<T>{_mm512_unpacklo_epi32(a.raw, b.raw)}; |
| } |
| template <typename T, HWY_IF_T_SIZE(T, 8)> |
| HWY_API Vec512<T> InterleaveLower(Vec512<T> a, Vec512<T> b) { |
| return Vec512<T>{_mm512_unpacklo_epi64(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<float> InterleaveLower(Vec512<float> a, Vec512<float> b) { |
| return Vec512<float>{_mm512_unpacklo_ps(a.raw, b.raw)}; |
| } |
| HWY_API Vec512<double> InterleaveLower(Vec512<double> a, Vec512<double> b) { |
| return Vec512<double>{_mm512_unpacklo_pd(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ InterleaveUpper |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) { |
| return VFromD<D>{_mm512_unpackhi_epi8(a.raw, b.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> InterleaveUpper(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| using VU = VFromD<decltype(du)>; // for float16_t |
| return BitCast( |
| d, VU{_mm512_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) { |
| return VFromD<D>{_mm512_unpackhi_epi32(a.raw, b.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) { |
| return VFromD<D>{_mm512_unpackhi_epi64(a.raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) { |
| return VFromD<D>{_mm512_unpackhi_ps(a.raw, b.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) { |
| return VFromD<D>{_mm512_unpackhi_pd(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ Concat* halves |
| |
| // hiH,hiL loH,loL |-> hiL,loL (= lower halves) |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)> |
| HWY_API VFromD<D> ConcatLowerLower(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast(d, |
| VFromD<decltype(du)>{_mm512_shuffle_i32x4( |
| BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_BABA)}); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> ConcatLowerLower(D /* tag */, VFromD<D> hi, VFromD<D> lo) { |
| return VFromD<D>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API Vec512<double> ConcatLowerLower(D /* tag */, Vec512<double> hi, |
| Vec512<double> lo) { |
| return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)}; |
| } |
| |
| // hiH,hiL loH,loL |-> hiH,loH (= upper halves) |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)> |
| HWY_API VFromD<D> ConcatUpperUpper(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast(d, |
| VFromD<decltype(du)>{_mm512_shuffle_i32x4( |
| BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_DCDC)}); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> ConcatUpperUpper(D /* tag */, VFromD<D> hi, VFromD<D> lo) { |
| return VFromD<D>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API Vec512<double> ConcatUpperUpper(D /* tag */, Vec512<double> hi, |
| Vec512<double> lo) { |
| return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)}; |
| } |
| |
| // hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)> |
| HWY_API VFromD<D> ConcatLowerUpper(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast(d, |
| VFromD<decltype(du)>{_mm512_shuffle_i32x4( |
| BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_BADC)}); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> ConcatLowerUpper(D /* tag */, VFromD<D> hi, VFromD<D> lo) { |
| return VFromD<D>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API Vec512<double> ConcatLowerUpper(D /* tag */, Vec512<double> hi, |
| Vec512<double> lo) { |
| return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BADC)}; |
| } |
| |
| // hiH,hiL loH,loL |-> hiH,loL (= outer halves) |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)> |
| HWY_API VFromD<D> ConcatUpperLower(D d, VFromD<D> hi, VFromD<D> lo) { |
  // There is no imm8 blend in AVX512. Use a 16-bit blend because 32-bit masks
  // are efficiently loaded from 32-bit regs.
| const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF); |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast(d, VFromD<decltype(du)>{_mm512_mask_blend_epi16( |
| mask, BitCast(du, hi).raw, BitCast(du, lo).raw)}); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> ConcatUpperLower(D /* tag */, VFromD<D> hi, VFromD<D> lo) { |
| const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF); |
| return VFromD<D>{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API Vec512<double> ConcatUpperLower(D /* tag */, Vec512<double> hi, |
| Vec512<double> lo) { |
| const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F); |
| return Vec512<double>{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)}; |
| } |
| |
| // ------------------------------ ConcatOdd |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| #if HWY_TARGET <= HWY_AVX3_DL |
| alignas(64) static constexpr uint8_t kIdx[64] = { |
| 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, |
| 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, |
| 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, |
| 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, |
| 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127}; |
| return BitCast( |
| d, Vec512<uint8_t>{_mm512_permutex2var_epi8( |
| BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); |
| #else |
| const RepartitionToWide<decltype(du)> dw; |
| // Right-shift 8 bits per u16 so we can pack. |
| const Vec512<uint16_t> uH = ShiftRight<8>(BitCast(dw, hi)); |
| const Vec512<uint16_t> uL = ShiftRight<8>(BitCast(dw, lo)); |
| const Vec512<uint64_t> u8{_mm512_packus_epi16(uL.raw, uH.raw)}; |
| // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. |
| const Full512<uint64_t> du64; |
| alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; |
| return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint16_t kIdx[32] = { |
| 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, |
| 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; |
| return BitCast( |
| d, Vec512<uint16_t>{_mm512_permutex2var_epi16( |
| BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint32_t kIdx[16] = { |
| 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; |
| return BitCast( |
| d, Vec512<uint32_t>{_mm512_permutex2var_epi32( |
| BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint32_t kIdx[16] = { |
| 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; |
| return VFromD<D>{_mm512_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; |
| return BitCast( |
| d, Vec512<uint64_t>{_mm512_permutex2var_epi64( |
| BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; |
| return VFromD<D>{_mm512_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; |
| } |
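| |
| // Illustrative semantics (sketch for int32_t lanes): ConcatOdd gathers the |
| // odd-indexed lanes of lo into the lower half of the result and the |
| // odd-indexed lanes of hi into the upper half: |
| //   ConcatOdd(d, hi, lo) == {lo[1], lo[3], ..., lo[15], hi[1], hi[3], ..., hi[15]} |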
| |
| // ------------------------------ ConcatEven |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| #if HWY_TARGET <= HWY_AVX3_DL |
| alignas(64) static constexpr uint8_t kIdx[64] = { |
| 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, |
| 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, |
| 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, |
| 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, |
| 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126}; |
| return BitCast( |
| d, Vec512<uint8_t>{_mm512_permutex2var_epi8( |
| BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); |
| #else |
| const RepartitionToWide<decltype(du)> dw; |
| // Isolate lower 8 bits per u16 so we can pack. |
| const Vec512<uint16_t> mask = Set(dw, 0x00FF); |
| const Vec512<uint16_t> uH = And(BitCast(dw, hi), mask); |
| const Vec512<uint16_t> uL = And(BitCast(dw, lo), mask); |
| const Vec512<uint64_t> u8{_mm512_packus_epi16(uL.raw, uH.raw)}; |
| // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. |
| const Full512<uint64_t> du64; |
| alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; |
| return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint16_t kIdx[32] = { |
| 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, |
| 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; |
| return BitCast( |
| d, Vec512<uint16_t>{_mm512_permutex2var_epi16( |
| BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint32_t kIdx[16] = { |
| 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; |
| return BitCast( |
| d, Vec512<uint32_t>{_mm512_permutex2var_epi32( |
| BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint32_t kIdx[16] = { |
| 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; |
| return VFromD<D>{_mm512_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; |
| return BitCast( |
| d, Vec512<uint64_t>{_mm512_permutex2var_epi64( |
| BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; |
| return VFromD<D>{_mm512_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; |
| } |
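| |
| // ConcatEven mirrors ConcatOdd with the even-indexed lanes, e.g. for int32_t: |
| //   ConcatEven(d, hi, lo) == {lo[0], lo[2], ..., lo[14], hi[0], hi[2], ..., hi[14]} |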
| |
| // ------------------------------ InterleaveWholeLower |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint8_t kIdx[64] = { |
| 0, 64, 1, 65, 2, 66, 3, 67, 4, 68, 5, 69, 6, 70, 7, 71, |
| 8, 72, 9, 73, 10, 74, 11, 75, 12, 76, 13, 77, 14, 78, 15, 79, |
| 16, 80, 17, 81, 18, 82, 19, 83, 20, 84, 21, 85, 22, 86, 23, 87, |
| 24, 88, 25, 89, 26, 90, 27, 91, 28, 92, 29, 93, 30, 94, 31, 95}; |
| return VFromD<D>{_mm512_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; |
| #else |
| alignas(64) static constexpr uint64_t kIdx2[8] = {0, 1, 8, 9, 2, 3, 10, 11}; |
| const Repartition<uint64_t, decltype(d)> du64; |
| return VFromD<D>{_mm512_permutex2var_epi64(InterleaveLower(a, b).raw, |
| Load(du64, kIdx2).raw, |
| InterleaveUpper(d, a, b).raw)}; |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint16_t kIdx[32] = { |
| 0, 32, 1, 33, 2, 34, 3, 35, 4, 36, 5, 37, 6, 38, 7, 39, |
| 8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47}; |
| return BitCast( |
| d, VFromD<decltype(du)>{_mm512_permutex2var_epi16( |
| BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint32_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, |
| 4, 20, 5, 21, 6, 22, 7, 23}; |
| return VFromD<D>{_mm512_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint32_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, |
| 4, 20, 5, 21, 6, 22, 7, 23}; |
| return VFromD<D>{_mm512_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint64_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; |
| return VFromD<D>{_mm512_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint64_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; |
| return VFromD<D>{_mm512_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; |
| } |
| |
| // ------------------------------ InterleaveWholeUpper |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint8_t kIdx[64] = { |
| 32, 96, 33, 97, 34, 98, 35, 99, 36, 100, 37, 101, 38, 102, 39, 103, |
| 40, 104, 41, 105, 42, 106, 43, 107, 44, 108, 45, 109, 46, 110, 47, 111, |
| 48, 112, 49, 113, 50, 114, 51, 115, 52, 116, 53, 117, 54, 118, 55, 119, |
| 56, 120, 57, 121, 58, 122, 59, 123, 60, 124, 61, 125, 62, 126, 63, 127}; |
| return VFromD<D>{_mm512_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; |
| #else |
| alignas(64) static constexpr uint64_t kIdx2[8] = {4, 5, 12, 13, 6, 7, 14, 15}; |
| const Repartition<uint64_t, decltype(d)> du64; |
| return VFromD<D>{_mm512_permutex2var_epi64(InterleaveLower(a, b).raw, |
| Load(du64, kIdx2).raw, |
| InterleaveUpper(d, a, b).raw)}; |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint16_t kIdx[32] = { |
| 16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55, |
| 24, 56, 25, 57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63}; |
| return BitCast( |
| d, VFromD<decltype(du)>{_mm512_permutex2var_epi16( |
| BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> |
| HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint32_t kIdx[16] = { |
| 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; |
| return VFromD<D>{_mm512_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint32_t kIdx[16] = { |
| 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; |
| return VFromD<D>{_mm512_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> |
| HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint64_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; |
| return VFromD<D>{_mm512_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) { |
| const RebindToUnsigned<decltype(d)> du; |
| alignas(64) static constexpr uint64_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; |
| return VFromD<D>{_mm512_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; |
| } |
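| |
| // Unlike InterleaveLower/InterleaveUpper, which operate within each 128-bit |
| // block, InterleaveWholeLower/Upper interleave across the entire vector, e.g. |
| // for int32_t (a sketch): |
| //   InterleaveWholeLower(d, a, b) == {a[0], b[0], a[1], b[1], ..., a[7], b[7]} |
| //   InterleaveWholeUpper(d, a, b) == {a[8], b[8], a[9], b[9], ..., a[15], b[15]} |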
| |
| // ------------------------------ DupEven (InterleaveLower) |
| |
| template <typename T, HWY_IF_T_SIZE(T, 4)> |
| HWY_API Vec512<T> DupEven(Vec512<T> v) { |
| return Vec512<T>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CCAA)}; |
| } |
| HWY_API Vec512<float> DupEven(Vec512<float> v) { |
| return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CCAA)}; |
| } |
| |
| template <typename T, HWY_IF_T_SIZE(T, 8)> |
| HWY_API Vec512<T> DupEven(const Vec512<T> v) { |
| const DFromV<decltype(v)> d; |
| return InterleaveLower(d, v, v); |
| } |
| |
| // ------------------------------ DupOdd (InterleaveUpper) |
| |
| template <typename T, HWY_IF_T_SIZE(T, 4)> |
| HWY_API Vec512<T> DupOdd(Vec512<T> v) { |
| return Vec512<T>{_mm512_shuffle_epi32(v.raw, _MM_PERM_DDBB)}; |
| } |
| HWY_API Vec512<float> DupOdd(Vec512<float> v) { |
| return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_DDBB)}; |
| } |
| |
| template <typename T, HWY_IF_T_SIZE(T, 8)> |
| HWY_API Vec512<T> DupOdd(const Vec512<T> v) { |
| const DFromV<decltype(v)> d; |
| return InterleaveUpper(d, v, v); |
| } |
| |
| // ------------------------------ OddEven (IfThenElse) |
| |
| template <typename T> |
| HWY_API Vec512<T> OddEven(const Vec512<T> a, const Vec512<T> b) { |
| constexpr size_t s = sizeof(T); |
| constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56; |
| return IfThenElse(Mask512<T>{0x5555555555555555ull >> shift}, b, a); |
| } |
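| |
| // Mask derivation above (worked example): with 64 / sizeof(T) lanes, the |
| // alternating pattern 0x5555...ull is shifted right so that exactly one bit |
| // per lane remains, e.g. for T = int32_t (16 lanes): |
| //   0x5555555555555555ull >> 48 == 0x5555, which selects the even lanes (b); |
| //   the odd lanes keep a. |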
| |
| // -------------------------- InterleaveEven |
| |
| template <class D, HWY_IF_LANES_D(D, 16), HWY_IF_UI32_D(D)> |
| HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) { |
| return VFromD<D>{_mm512_mask_shuffle_epi32( |
| a.raw, static_cast<__mmask16>(0xAAAA), b.raw, |
| static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))}; |
| } |
| template <class D, HWY_IF_LANES_D(D, 16), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) { |
| return VFromD<D>{_mm512_mask_shuffle_ps(a.raw, static_cast<__mmask16>(0xAAAA), |
| b.raw, b.raw, |
| _MM_SHUFFLE(2, 2, 0, 0))}; |
| } |
| // -------------------------- InterleaveOdd |
| |
| template <class D, HWY_IF_LANES_D(D, 16), HWY_IF_UI32_D(D)> |
| HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) { |
| return VFromD<D>{_mm512_mask_shuffle_epi32( |
| b.raw, static_cast<__mmask16>(0x5555), a.raw, |
| static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))}; |
| } |
| template <class D, HWY_IF_LANES_D(D, 16), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) { |
| return VFromD<D>{_mm512_mask_shuffle_ps(b.raw, static_cast<__mmask16>(0x5555), |
| a.raw, a.raw, |
| _MM_SHUFFLE(3, 3, 1, 1))}; |
| } |
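| |
| // Sketch of the masked shuffles above: InterleaveEven keeps a and overwrites |
| // its odd lanes (mask 0xAAAA) with b shuffled so that b[2i] lands at lane |
| // 2i+1, yielding {a[0], b[0], a[2], b[2], ...}. InterleaveOdd keeps b's odd |
| // lanes and writes a's odd lanes into the even positions (mask 0x5555), |
| // yielding {a[1], b[1], a[3], b[3], ...}. |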
| |
| // ------------------------------ OddEvenBlocks |
| |
| template <typename T> |
| HWY_API Vec512<T> OddEvenBlocks(Vec512<T> odd, Vec512<T> even) { |
| const DFromV<decltype(odd)> d; |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast( |
| d, VFromD<decltype(du)>{_mm512_mask_blend_epi64( |
| __mmask8{0x33u}, BitCast(du, odd).raw, BitCast(du, even).raw)}); |
| } |
| |
| HWY_API Vec512<float> OddEvenBlocks(Vec512<float> odd, Vec512<float> even) { |
| return Vec512<float>{ |
| _mm512_mask_blend_ps(__mmask16{0x0F0Fu}, odd.raw, even.raw)}; |
| } |
| |
| HWY_API Vec512<double> OddEvenBlocks(Vec512<double> odd, Vec512<double> even) { |
| return Vec512<double>{ |
| _mm512_mask_blend_pd(__mmask8{0x33u}, odd.raw, even.raw)}; |
| } |
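| |
| // Mask derivation: each 128-bit block spans 2 u64 / 4 f32 lanes, and blocks 0 |
| // and 2 take `even`, hence the u64 blend mask 0b00110011 (0x33) and the f32 |
| // blend mask 0x0F0F. |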
| |
| // ------------------------------ SwapAdjacentBlocks |
| |
| template <typename T> |
| HWY_API Vec512<T> SwapAdjacentBlocks(Vec512<T> v) { |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast(d, |
| VFromD<decltype(du)>{_mm512_shuffle_i32x4( |
| BitCast(du, v).raw, BitCast(du, v).raw, _MM_PERM_CDAB)}); |
| } |
| |
| HWY_API Vec512<float> SwapAdjacentBlocks(Vec512<float> v) { |
| return Vec512<float>{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_CDAB)}; |
| } |
| |
| HWY_API Vec512<double> SwapAdjacentBlocks(Vec512<double> v) { |
| return Vec512<double>{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_CDAB)}; |
| } |
| |
| // ------------------------------ InterleaveEvenBlocks |
| template <typename T> |
| HWY_API Vec512<T> InterleaveEvenBlocks(Full512<T> d, Vec512<T> a, Vec512<T> b) { |
| return OddEvenBlocks(SlideUpBlocks<1>(d, b), a); |
| } |
| |
| // ------------------------------ InterleaveOddBlocks (ConcatUpperUpper) |
| template <typename T> |
| HWY_API Vec512<T> InterleaveOddBlocks(Full512<T> d, Vec512<T> a, Vec512<T> b) { |
| return OddEvenBlocks(b, SlideDownBlocks<1>(d, a)); |
| } |
| |
| // ------------------------------ ReverseBlocks |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)> |
| HWY_API VFromD<D> ReverseBlocks(D d, VFromD<D> v) { |
| const RebindToUnsigned<decltype(d)> du; // for float16_t |
| return BitCast(d, |
| VFromD<decltype(du)>{_mm512_shuffle_i32x4( |
| BitCast(du, v).raw, BitCast(du, v).raw, _MM_PERM_ABCD)}); |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> ReverseBlocks(D /* tag */, VFromD<D> v) { |
| return VFromD<D>{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_ABCD)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> ReverseBlocks(D /* tag */, VFromD<D> v) { |
| return VFromD<D>{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_ABCD)}; |
| } |
| |
| // ------------------------------ TableLookupBytes (ZeroExtendVector) |
| |
| // Both full |
| template <typename T, typename TI> |
| HWY_API Vec512<TI> TableLookupBytes(Vec512<T> bytes, Vec512<TI> indices) { |
| const DFromV<decltype(indices)> d; |
| return BitCast(d, Vec512<uint8_t>{_mm512_shuffle_epi8( |
| BitCast(Full512<uint8_t>(), bytes).raw, |
| BitCast(Full512<uint8_t>(), indices).raw)}); |
| } |
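| |
| // Note: per _mm512_shuffle_epi8 semantics, each index selects a byte from |
| // within its own 128-bit block of `bytes`, and an index byte with the high |
| // bit set produces a zero byte. |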
| |
| // Partial index vector |
| template <typename T, typename TI, size_t NI> |
| HWY_API Vec128<TI, NI> TableLookupBytes(Vec512<T> bytes, Vec128<TI, NI> from) { |
| const Full512<TI> d512; |
| const Half<decltype(d512)> d256; |
| const Half<decltype(d256)> d128; |
| // First expand to full 128, then 256, then 512. |
| const Vec128<TI> from_full{from.raw}; |
| const auto from_512 = |
| ZeroExtendVector(d512, ZeroExtendVector(d256, from_full)); |
| const auto tbl_full = TableLookupBytes(bytes, from_512); |
| // Shrink to 256, then 128, then partial. |
| return Vec128<TI, NI>{LowerHalf(d128, LowerHalf(d256, tbl_full)).raw}; |
| } |
| template <typename T, typename TI> |
| HWY_API Vec256<TI> TableLookupBytes(Vec512<T> bytes, Vec256<TI> from) { |
| const DFromV<decltype(from)> dih; |
| const Twice<decltype(dih)> di; |
| const auto from_512 = ZeroExtendVector(di, from); |
| return LowerHalf(dih, TableLookupBytes(bytes, from_512)); |
| } |
| |
| // Partial table vector |
| template <typename T, size_t N, typename TI> |
| HWY_API Vec512<TI> TableLookupBytes(Vec128<T, N> bytes, Vec512<TI> from) { |
| const DFromV<decltype(from)> d512; |
| const Half<decltype(d512)> d256; |
| const Half<decltype(d256)> d128; |
| // First expand to full 128, then 256, then 512. |
| const Vec128<T> bytes_full{bytes.raw}; |
| const auto bytes_512 = |
| ZeroExtendVector(d512, ZeroExtendVector(d256, bytes_full)); |
| return TableLookupBytes(bytes_512, from); |
| } |
| template <typename T, typename TI> |
| HWY_API Vec512<TI> TableLookupBytes(Vec256<T> bytes, Vec512<TI> from) { |
| const Full512<T> d; |
| return TableLookupBytes(ZeroExtendVector(d, bytes), from); |
| } |
| |
| // Partial table and partial index vectors are both handled by x86_128/256. |
| |
| // ------------------------------ I8/U8 Broadcast (TableLookupBytes) |
| |
| template <int kLane, class T, HWY_IF_T_SIZE(T, 1)> |
| HWY_API Vec512<T> Broadcast(const Vec512<T> v) { |
| static_assert(0 <= kLane && kLane < 16, "Invalid lane"); |
| return TableLookupBytes(v, Set(Full512<T>(), static_cast<T>(kLane))); |
| } |
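| |
| // Illustrative: Broadcast<2>(v) replicates byte 2 of each 128-bit block to |
| // all 16 bytes of that block (a per-block broadcast, not a whole-vector one). |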
| |
| // ------------------------------ Per4LaneBlockShuffle |
| |
| namespace detail { |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_INLINE VFromD<D> Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, |
| const uint32_t x2, |
| const uint32_t x1, |
| const uint32_t x0) { |
| return BitCast(d, Vec512<uint32_t>{_mm512_set_epi32( |
| static_cast<int32_t>(x3), static_cast<int32_t>(x2), |
| static_cast<int32_t>(x1), static_cast<int32_t>(x0), |
| static_cast<int32_t>(x3), static_cast<int32_t>(x2), |
| static_cast<int32_t>(x1), static_cast<int32_t>(x0), |
| static_cast<int32_t>(x3), static_cast<int32_t>(x2), |
| static_cast<int32_t>(x1), static_cast<int32_t>(x0), |
| static_cast<int32_t>(x3), static_cast<int32_t>(x2), |
| static_cast<int32_t>(x1), static_cast<int32_t>(x0))}); |
| } |
| |
| template <size_t kIdx3210, class V, HWY_IF_NOT_FLOAT(TFromV<V>)> |
| HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/, |
| hwy::SizeTag<4> /*lane_size_tag*/, |
| hwy::SizeTag<64> /*vect_size_tag*/, V v) { |
| return V{ |
| _mm512_shuffle_epi32(v.raw, static_cast<_MM_PERM_ENUM>(kIdx3210 & 0xFF))}; |
| } |
| |
| template <size_t kIdx3210, class V, HWY_IF_FLOAT(TFromV<V>)> |
| HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/, |
| hwy::SizeTag<4> /*lane_size_tag*/, |
| hwy::SizeTag<64> /*vect_size_tag*/, V v) { |
| return V{_mm512_shuffle_ps(v.raw, v.raw, static_cast<int>(kIdx3210 & 0xFF))}; |
| } |
| |
| template <size_t kIdx3210, class V, HWY_IF_NOT_FLOAT(TFromV<V>)> |
| HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/, |
| hwy::SizeTag<8> /*lane_size_tag*/, |
| hwy::SizeTag<64> /*vect_size_tag*/, V v) { |
| return V{_mm512_permutex_epi64(v.raw, static_cast<int>(kIdx3210 & 0xFF))}; |
| } |
| |
| template <size_t kIdx3210, class V, HWY_IF_FLOAT(TFromV<V>)> |
| HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/, |
| hwy::SizeTag<8> /*lane_size_tag*/, |
| hwy::SizeTag<64> /*vect_size_tag*/, V v) { |
| return V{_mm512_permutex_pd(v.raw, static_cast<int>(kIdx3210 & 0xFF))}; |
| } |
| |
| } // namespace detail |
| |
| // ------------------------------ SlideUpLanes |
| |
| namespace detail { |
| |
| template <int kI32Lanes, class V, HWY_IF_V_SIZE_V(V, 64)> |
| HWY_INLINE V CombineShiftRightI32Lanes(V hi, V lo) { |
| const DFromV<decltype(hi)> d; |
| const Repartition<uint32_t, decltype(d)> du32; |
| return BitCast(d, |
| Vec512<uint32_t>{_mm512_alignr_epi32( |
| BitCast(du32, hi).raw, BitCast(du32, lo).raw, kI32Lanes)}); |
| } |
| |
| template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 64)> |
| HWY_INLINE V CombineShiftRightI64Lanes(V hi, V lo) { |
| const DFromV<decltype(hi)> d; |
| const Repartition<uint64_t, decltype(d)> du64; |
| return BitCast(d, |
| Vec512<uint64_t>{_mm512_alignr_epi64( |
| BitCast(du64, hi).raw, BitCast(du64, lo).raw, kI64Lanes)}); |
| } |
| |
| template <int kI32Lanes, class V, HWY_IF_V_SIZE_V(V, 64)> |
| HWY_INLINE V SlideUpI32Lanes(V v) { |
| static_assert(0 <= kI32Lanes && kI32Lanes <= 15, |
| "kI32Lanes must be between 0 and 15"); |
| const DFromV<decltype(v)> d; |
| return CombineShiftRightI32Lanes<16 - kI32Lanes>(v, Zero(d)); |
| } |
| |
| template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 64)> |
| HWY_INLINE V SlideUpI64Lanes(V v) { |
| static_assert(0 <= kI64Lanes && kI64Lanes <= 7, |
| "kI64Lanes must be between 0 and 7"); |
| const DFromV<decltype(v)> d; |
| return CombineShiftRightI64Lanes<8 - kI64Lanes>(v, Zero(d)); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_INLINE VFromD<D> TableLookupSlideUpLanes(D d, VFromD<D> v, size_t amt) { |
| const Repartition<uint8_t, decltype(d)> du8; |
| |
| #if HWY_TARGET <= HWY_AVX3_DL |
| const auto byte_idx = Iota(du8, static_cast<uint8_t>(size_t{0} - amt)); |
| return TwoTablesLookupLanes(v, Zero(d), Indices512<TFromD<D>>{byte_idx.raw}); |
| #else |
| const Repartition<uint16_t, decltype(d)> du16; |
| const Repartition<uint64_t, decltype(d)> du64; |
| const auto byte_idx = Iota(du8, static_cast<uint8_t>(size_t{0} - (amt & 15))); |
| const auto blk_u64_idx = |
| Iota(du64, static_cast<uint64_t>(uint64_t{0} - ((amt >> 4) << 1))); |
| |
| const VFromD<D> even_blocks{ |
| _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; |
| const VFromD<D> odd_blocks{ |
| _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(3, 1, 1, 3))}; |
| const auto odd_sel_mask = |
| MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, byte_idx)))); |
| const auto even_blk_lookup_result = |
| BitCast(d, TableLookupBytes(even_blocks, byte_idx)); |
| const VFromD<D> blockwise_slide_up_result{ |
| _mm512_mask_shuffle_epi8(even_blk_lookup_result.raw, odd_sel_mask.raw, |
| odd_blocks.raw, byte_idx.raw)}; |
| return BitCast(d, TwoTablesLookupLanes( |
| BitCast(du64, blockwise_slide_up_result), Zero(du64), |
| Indices512<uint64_t>{blk_u64_idx.raw})); |
| #endif |
| } |
| |
| } // namespace detail |
| |
| template <int kBlocks, class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> SlideUpBlocks(D d, VFromD<D> v) { |
| static_assert(0 <= kBlocks && kBlocks <= 3, |
| "kBlocks must be between 0 and 3"); |
| switch (kBlocks) { |
| case 0: |
| return v; |
| case 1: |
| return detail::SlideUpI64Lanes<2>(v); |
| case 2: |
| return ConcatLowerLower(d, v, Zero(d)); |
| case 3: |
| return detail::SlideUpI64Lanes<6>(v); |
| } |
| |
| return v; |
| } |
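| |
| // Illustrative usage (sketch): SlideUpBlocks<1>(d, v) shifts v up by one |
| // 16-byte block, so result block 0 is zero and result block i equals |
| // v's block i-1. |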
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)> |
| HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) { |
| #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang |
| if (__builtin_constant_p(amt)) { |
| switch (amt) { |
| case 0: |
| return v; |
| case 1: |
| return detail::SlideUpI32Lanes<1>(v); |
| case 2: |
| return detail::SlideUpI64Lanes<1>(v); |
| case 3: |
| return detail::SlideUpI32Lanes<3>(v); |
| case 4: |
| return detail::SlideUpI64Lanes<2>(v); |
| case 5: |
| return detail::SlideUpI32Lanes<5>(v); |
| case 6: |
| return detail::SlideUpI64Lanes<3>(v); |
| case 7: |
| return detail::SlideUpI32Lanes<7>(v); |
| case 8: |
| return ConcatLowerLower(d, v, Zero(d)); |
| case 9: |
| return detail::SlideUpI32Lanes<9>(v); |
| case 10: |
| return detail::SlideUpI64Lanes<5>(v); |
| case 11: |
| return detail::SlideUpI32Lanes<11>(v); |
| case 12: |
| return detail::SlideUpI64Lanes<6>(v); |
| case 13: |
| return detail::SlideUpI32Lanes<13>(v); |
| case 14: |
| return detail::SlideUpI64Lanes<7>(v); |
| case 15: |
| return detail::SlideUpI32Lanes<15>(v); |
| } |
| } |
| #endif |
| |
| return detail::TableLookupSlideUpLanes(d, v, amt); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)> |
| HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) { |
| #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang |
| if (__builtin_constant_p(amt)) { |
| switch (amt) { |
| case 0: |
| return v; |
| case 1: |
| return detail::SlideUpI64Lanes<1>(v); |
| case 2: |
| return detail::SlideUpI64Lanes<2>(v); |
| case 3: |
| return detail::SlideUpI64Lanes<3>(v); |
| case 4: |
| return ConcatLowerLower(d, v, Zero(d)); |
| case 5: |
| return detail::SlideUpI64Lanes<5>(v); |
| case 6: |
| return detail::SlideUpI64Lanes<6>(v); |
| case 7: |
| return detail::SlideUpI64Lanes<7>(v); |
| } |
| } |
| #endif |
| |
| return detail::TableLookupSlideUpLanes(d, v, amt); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) { |
| #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang |
| if (__builtin_constant_p(amt)) { |
| if ((amt & 3) == 0) { |
| const Repartition<uint32_t, decltype(d)> du32; |
| return BitCast(d, SlideUpLanes(du32, BitCast(du32, v), amt >> 2)); |
| } else if ((amt & 1) == 0) { |
| const Repartition<uint16_t, decltype(d)> du16; |
| return BitCast( |
| d, detail::TableLookupSlideUpLanes(du16, BitCast(du16, v), amt >> 1)); |
| } |
| #if HWY_TARGET > HWY_AVX3_DL |
| else if (amt <= 63) { // NOLINT(readability/braces) |
| const Repartition<uint64_t, decltype(d)> du64; |
| const size_t blk_u64_slideup_amt = (amt >> 4) << 1; |
| const auto vu64 = BitCast(du64, v); |
| const auto v_hi = |
| BitCast(d, SlideUpLanes(du64, vu64, blk_u64_slideup_amt)); |
| const auto v_lo = |
| (blk_u64_slideup_amt <= 4) |
| ? BitCast(d, SlideUpLanes(du64, vu64, blk_u64_slideup_amt + 2)) |
| : Zero(d); |
| switch (amt & 15) { |
| case 1: |
| return CombineShiftRightBytes<15>(d, v_hi, v_lo); |
| case 3: |
| return CombineShiftRightBytes<13>(d, v_hi, v_lo); |
| case 5: |
| return CombineShiftRightBytes<11>(d, v_hi, v_lo); |
| case 7: |
| return CombineShiftRightBytes<9>(d, v_hi, v_lo); |
| case 9: |
| return CombineShiftRightBytes<7>(d, v_hi, v_lo); |
| case 11: |
| return CombineShiftRightBytes<5>(d, v_hi, v_lo); |
| case 13: |
| return CombineShiftRightBytes<3>(d, v_hi, v_lo); |
| case 15: |
| return CombineShiftRightBytes<1>(d, v_hi, v_lo); |
| } |
| } |
| #endif // HWY_TARGET > HWY_AVX3_DL |
| } |
| #endif |
| |
| return detail::TableLookupSlideUpLanes(d, v, amt); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) { |
| #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang |
| if (__builtin_constant_p(amt) && (amt & 1) == 0) { |
| const Repartition<uint32_t, decltype(d)> du32; |
| return BitCast(d, SlideUpLanes(du32, BitCast(du32, v), amt >> 1)); |
| } |
| #endif |
| |
| return detail::TableLookupSlideUpLanes(d, v, amt); |
| } |
| |
| // ------------------------------ Slide1Up |
| |
| template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return detail::TableLookupSlideUpLanes(d, v, 1); |
| #else |
| const auto v_lo = detail::SlideUpI64Lanes<2>(v); |
| return CombineShiftRightBytes<15>(d, v, v_lo); |
| #endif |
| } |
| |
| template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) { |
| return detail::TableLookupSlideUpLanes(d, v, 1); |
| } |
| |
| template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)> |
| HWY_API VFromD<D> Slide1Up(D /*d*/, VFromD<D> v) { |
| return detail::SlideUpI32Lanes<1>(v); |
| } |
| |
| template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)> |
| HWY_API VFromD<D> Slide1Up(D /*d*/, VFromD<D> v) { |
| return detail::SlideUpI64Lanes<1>(v); |
| } |
| |
| // ------------------------------ SlideDownLanes |
| |
| namespace detail { |
| |
| template <int kI32Lanes, class V, HWY_IF_V_SIZE_V(V, 64)> |
| HWY_INLINE V SlideDownI32Lanes(V v) { |
| static_assert(0 <= kI32Lanes && kI32Lanes <= 15, |
| "kI32Lanes must be between 0 and 15"); |
| const DFromV<decltype(v)> d; |
| return CombineShiftRightI32Lanes<kI32Lanes>(Zero(d), v); |
| } |
| |
| template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 64)> |
| HWY_INLINE V SlideDownI64Lanes(V v) { |
| static_assert(0 <= kI64Lanes && kI64Lanes <= 7, |
| "kI64Lanes must be between 0 and 7"); |
| const DFromV<decltype(v)> d; |
| return CombineShiftRightI64Lanes<kI64Lanes>(Zero(d), v); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_INLINE VFromD<D> TableLookupSlideDownLanes(D d, VFromD<D> v, size_t amt) { |
| const Repartition<uint8_t, decltype(d)> du8; |
| |
| #if HWY_TARGET <= HWY_AVX3_DL |
| auto byte_idx = Iota(du8, static_cast<uint8_t>(amt)); |
| return TwoTablesLookupLanes(v, Zero(d), Indices512<TFromD<D>>{byte_idx.raw}); |
| #else |
| const Repartition<uint16_t, decltype(d)> du16; |
| const Repartition<uint64_t, decltype(d)> du64; |
| const auto byte_idx = Iota(du8, static_cast<uint8_t>(amt & 15)); |
| const auto blk_u64_idx = Iota(du64, static_cast<uint64_t>(((amt >> 4) << 1))); |
| |
| const VFromD<D> even_blocks{ |
| _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(0, 2, 2, 0))}; |
| const VFromD<D> odd_blocks{ |
| _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; |
| const auto odd_sel_mask = |
| MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, byte_idx)))); |
| const VFromD<D> even_blk_lookup_result{ |
| _mm512_maskz_shuffle_epi8(static_cast<__mmask64>(0x0000FFFFFFFFFFFFULL), |
| even_blocks.raw, byte_idx.raw)}; |
| const VFromD<D> blockwise_slide_down_result{ |
| _mm512_mask_shuffle_epi8(even_blk_lookup_result.raw, odd_sel_mask.raw, |
| odd_blocks.raw, byte_idx.raw)}; |
| return BitCast(d, TwoTablesLookupLanes( |
| BitCast(du64, blockwise_slide_down_result), Zero(du64), |
| Indices512<uint64_t>{blk_u64_idx.raw})); |
| #endif |
| } |
| |
| } // namespace detail |
| |
| template <int kBlocks, class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API VFromD<D> SlideDownBlocks(D d, VFromD<D> v) { |
| static_assert(0 <= kBlocks && kBlocks <= 3, |
| "kBlocks must be between 0 and 3"); |
| const Half<decltype(d)> dh; |
| switch (kBlocks) { |
| case 0: |
| return v; |
| case 1: |
| return detail::SlideDownI64Lanes<2>(v); |
| case 2: |
| return ZeroExtendVector(d, UpperHalf(dh, v)); |
| case 3: |
| return detail::SlideDownI64Lanes<6>(v); |
| } |
| |
| return v; |
| } |
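| |
| // Illustrative usage (sketch): SlideDownBlocks<1>(d, v) shifts v down by one |
| // 16-byte block, so result block i equals v's block i+1 and the top block is |
| // zero. |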
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)> |
| HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) { |
| #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang |
| if (__builtin_constant_p(amt)) { |
| const Half<decltype(d)> dh; |
| switch (amt) { |
| case 0: |
| return v; |
| case 1: |
| return detail::SlideDownI32Lanes<1>(v); |
| case 2: |
| return detail::SlideDownI64Lanes<1>(v); |
| case 3: |
| return detail::SlideDownI32Lanes<3>(v); |
| case 4: |
| return detail::SlideDownI64Lanes<2>(v); |
| case 5: |
| return detail::SlideDownI32Lanes<5>(v); |
| case 6: |
| return detail::SlideDownI64Lanes<3>(v); |
| case 7: |
| return detail::SlideDownI32Lanes<7>(v); |
| case 8: |
| return ZeroExtendVector(d, UpperHalf(dh, v)); |
| case 9: |
| return detail::SlideDownI32Lanes<9>(v); |
| case 10: |
| return detail::SlideDownI64Lanes<5>(v); |
| case 11: |
| return detail::SlideDownI32Lanes<11>(v); |
| case 12: |
| return detail::SlideDownI64Lanes<6>(v); |
| case 13: |
| return detail::SlideDownI32Lanes<13>(v); |
| case 14: |
| return detail::SlideDownI64Lanes<7>(v); |
| case 15: |
| return detail::SlideDownI32Lanes<15>(v); |
| } |
| } |
| #endif |
| |
| return detail::TableLookupSlideDownLanes(d, v, amt); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)> |
| HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) { |
| #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang |
| if (__builtin_constant_p(amt)) { |
| const Half<decltype(d)> dh; |
| switch (amt) { |
| case 0: |
| return v; |
| case 1: |
| return detail::SlideDownI64Lanes<1>(v); |
| case 2: |
| return detail::SlideDownI64Lanes<2>(v); |
| case 3: |
| return detail::SlideDownI64Lanes<3>(v); |
| case 4: |
| return ZeroExtendVector(d, UpperHalf(dh, v)); |
| case 5: |
| return detail::SlideDownI64Lanes<5>(v); |
| case 6: |
| return detail::SlideDownI64Lanes<6>(v); |
| case 7: |
| return detail::SlideDownI64Lanes<7>(v); |
| } |
| } |
| #endif |
| |
| return detail::TableLookupSlideDownLanes(d, v, amt); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) { |
| #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang |
| if (__builtin_constant_p(amt)) { |
| if ((amt & 3) == 0) { |
| const Repartition<uint32_t, decltype(d)> du32; |
| return BitCast(d, SlideDownLanes(du32, BitCast(du32, v), amt >> 2)); |
| } else if ((amt & 1) == 0) { |
| const Repartition<uint16_t, decltype(d)> du16; |
| return BitCast(d, detail::TableLookupSlideDownLanes( |
| du16, BitCast(du16, v), amt >> 1)); |
| } |
| #if HWY_TARGET > HWY_AVX3_DL |
| else if (amt <= 63) { // NOLINT(readability/braces) |
| const Repartition<uint64_t, decltype(d)> du64; |
| const size_t blk_u64_slidedown_amt = (amt >> 4) << 1; |
| const auto vu64 = BitCast(du64, v); |
| const auto v_lo = |
| BitCast(d, SlideDownLanes(du64, vu64, blk_u64_slidedown_amt)); |
| const auto v_hi = |
| (blk_u64_slidedown_amt <= 4) |
| ? BitCast(d, |
| SlideDownLanes(du64, vu64, blk_u64_slidedown_amt + 2)) |
| : Zero(d); |
| switch (amt & 15) { |
| case 1: |
| return CombineShiftRightBytes<1>(d, v_hi, v_lo); |
| case 3: |
| return CombineShiftRightBytes<3>(d, v_hi, v_lo); |
| case 5: |
| return CombineShiftRightBytes<5>(d, v_hi, v_lo); |
| case 7: |
| return CombineShiftRightBytes<7>(d, v_hi, v_lo); |
| case 9: |
| return CombineShiftRightBytes<9>(d, v_hi, v_lo); |
| case 11: |
| return CombineShiftRightBytes<11>(d, v_hi, v_lo); |
| case 13: |
| return CombineShiftRightBytes<13>(d, v_hi, v_lo); |
| case 15: |
| return CombineShiftRightBytes<15>(d, v_hi, v_lo); |
| } |
| } |
| #endif |
| } |
| #endif |
| |
| return detail::TableLookupSlideDownLanes(d, v, amt); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) { |
| #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang |
| if (__builtin_constant_p(amt) && (amt & 1) == 0) { |
| const Repartition<uint32_t, decltype(d)> du32; |
| return BitCast(d, SlideDownLanes(du32, BitCast(du32, v), amt >> 1)); |
| } |
| #endif |
| |
| return detail::TableLookupSlideDownLanes(d, v, amt); |
| } |
| |
| // ------------------------------ Slide1Down |
| |
| template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return detail::TableLookupSlideDownLanes(d, v, 1); |
| #else |
| const auto v_hi = detail::SlideDownI64Lanes<2>(v); |
| return CombineShiftRightBytes<1>(d, v_hi, v); |
| #endif |
| } |
| |
| template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) { |
| return detail::TableLookupSlideDownLanes(d, v, 1); |
| } |
| |
| template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)> |
| HWY_API VFromD<D> Slide1Down(D /*d*/, VFromD<D> v) { |
| return detail::SlideDownI32Lanes<1>(v); |
| } |
| |
| template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)> |
| HWY_API VFromD<D> Slide1Down(D /*d*/, VFromD<D> v) { |
| return detail::SlideDownI64Lanes<1>(v); |
| } |
| |
| // ================================================== CONVERT |
| |
| // ------------------------------ Promotions (part w/ narrow lanes -> full) |
| |
| // Unsigned: zero-extend. |
| // Note: these have 3 cycle latency; if inputs are already split across the |
| // 128-bit blocks (in their upper/lower halves), then Zip* would be faster. |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U16_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<uint8_t> v) { |
| return VFromD<D>{_mm512_cvtepu8_epi16(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U32_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint8_t> v) { |
| return VFromD<D>{_mm512_cvtepu8_epi32(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U32_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<uint16_t> v) { |
| return VFromD<D>{_mm512_cvtepu16_epi32(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<uint32_t> v) { |
| return VFromD<D>{_mm512_cvtepu32_epi64(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint16_t> v) { |
| return VFromD<D>{_mm512_cvtepu16_epi64(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec64<uint8_t> v) { |
| return VFromD<D>{_mm512_cvtepu8_epi64(v.raw)}; |
| } |
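| |
| // Illustrative usage (caller-side sketch; names are placeholders): |
| //   const Full512<uint16_t> d16; |
| //   const Rebind<uint8_t, decltype(d16)> d8;  // 256-bit source |
| //   const Vec512<uint16_t> wide = PromoteTo(d16, Load(d8, ptr)); |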
| |
| // Signed: replicate sign bit. |
| // Note: these have 3 cycle latency; if inputs are already split across the |
| // 128-bit blocks (in their upper/lower halves), then ZipUpper/ZipLower |
| // followed by a signed shift would be faster. |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I16_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<int8_t> v) { |
| return VFromD<D>{_mm512_cvtepi8_epi16(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int8_t> v) { |
| return VFromD<D>{_mm512_cvtepi8_epi32(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<int16_t> v) { |
| return VFromD<D>{_mm512_cvtepi16_epi32(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<int32_t> v) { |
| return VFromD<D>{_mm512_cvtepi32_epi64(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int16_t> v) { |
| return VFromD<D>{_mm512_cvtepi16_epi64(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec64<int8_t> v) { |
| return VFromD<D>{_mm512_cvtepi8_epi64(v.raw)}; |
| } |
| |
| // Float |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<float16_t> v) { |
| #if HWY_HAVE_FLOAT16 |
| const RebindToUnsigned<DFromV<decltype(v)>> du16; |
| return VFromD<D>{_mm512_cvtph_ps(BitCast(du16, v).raw)}; |
| #else |
| return VFromD<D>{_mm512_cvtph_ps(v.raw)}; |
| #endif // HWY_HAVE_FLOAT16 |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_INLINE VFromD<D> PromoteTo(D /*tag*/, Vec128<float16_t> v) { |
| return VFromD<D>{_mm512_cvtph_pd(v.raw)}; |
| } |
| |
| #endif // HWY_HAVE_FLOAT16 |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> PromoteTo(D df32, Vec256<bfloat16_t> v) { |
| const Rebind<uint16_t, decltype(df32)> du16; |
| const RebindToSigned<decltype(df32)> di32; |
| return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<float> v) { |
| return VFromD<D>{_mm512_cvtps_pd(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<int32_t> v) { |
| return VFromD<D>{_mm512_cvtepi32_pd(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<uint32_t> v) { |
| return VFromD<D>{_mm512_cvtepu32_pd(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)> |
| HWY_API VFromD<D> PromoteInRangeTo(D /*di64*/, VFromD<Rebind<float, D>> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvttps_epi64 with GCC if any |
| // values of v[i] are not within the range of an int64_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
| if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) { |
| typedef float GccF32RawVectType __attribute__((__vector_size__(32))); |
| const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); |
| return VFromD<D>{_mm512_setr_epi64( |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[0]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[1]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[2]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[3]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[4]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[5]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[6]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[7]))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvttps2qq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<D>{raw_result}; |
| #else |
| return VFromD<D>{_mm512_cvttps_epi64(v.raw)}; |
| #endif |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)> |
| HWY_API VFromD<D> PromoteInRangeTo(D /* tag */, VFromD<Rebind<float, D>> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvttps_epu64 with GCC if any |
| // values of v[i] are not within the range of a uint64_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
| if (detail::IsConstantX86VecForF2IConv<uint64_t>(v)) { |
| typedef float GccF32RawVectType __attribute__((__vector_size__(32))); |
| const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); |
| return VFromD<D>{_mm512_setr_epi64( |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[0])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[1])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[2])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[3])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[4])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[5])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[6])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[7])))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvttps2uqq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<D>{raw_result}; |
| #else |
| return VFromD<D>{_mm512_cvttps_epu64(v.raw)}; |
| #endif |
| } |
| |
| // ------------------------------ Demotions (full -> part w/ narrow lanes) |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int32_t> v) { |
| const Full512<uint64_t> du64; |
| const Vec512<uint16_t> u16{_mm512_packus_epi32(v.raw, v.raw)}; |
| |
| // Compress even u64 lanes into 256 bit. |
| alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; |
| const auto idx64 = Load(du64, kLanes); |
| const Vec512<uint16_t> even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)}; |
| return LowerHalf(even); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)> |
| HWY_API VFromD<D> DemoteTo(D dn, Vec512<uint32_t> v) { |
| const DFromV<decltype(v)> d; |
| const RebindToSigned<decltype(d)> di; |
| return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFFFFFu)))); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I16_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int32_t> v) { |
| const Full512<uint64_t> du64; |
| const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)}; |
| |
| // Compress even u64 lanes into 256 bit. |
| alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; |
| const auto idx64 = Load(du64, kLanes); |
| const Vec512<int16_t> even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)}; |
| return LowerHalf(even); |
| } |
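| |
| // Sketch of the pack-based demotions above: _mm512_packs_epi32(v, v) narrows |
| // within each 128-bit block, leaving the 4 desired lanes in that block's |
| // lower 64 bits; the permutexvar_epi64 with {0, 2, 4, 6, ...} then gathers |
| // those four u64 lanes into the lower 256 bits, which LowerHalf returns. |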
| |
| template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int32_t> v) { |
| const Full512<uint32_t> du32; |
| const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)}; |
| const Vec512<uint8_t> u8{_mm512_packus_epi16(i16.raw, i16.raw)}; |
| |
| const VFromD<decltype(du32)> idx32 = Dup128VecFromValues(du32, 0, 4, 8, 12); |
| const Vec512<uint8_t> fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)}; |
| return LowerHalf(LowerHalf(fixed)); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<uint32_t> v) { |
| return VFromD<D>{_mm512_cvtusepi32_epi8(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int16_t> v) { |
| const Full512<uint64_t> du64; |
| const Vec512<uint8_t> u8{_mm512_packus_epi16(v.raw, v.raw)}; |
| |
| // Compress even u64 lanes into 256 bit. |
| alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; |
| const auto idx64 = Load(du64, kLanes); |
| const Vec512<uint8_t> even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; |
| return LowerHalf(even); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> DemoteTo(D dn, Vec512<uint16_t> v) { |
| const DFromV<decltype(v)> d; |
| const RebindToSigned<decltype(d)> di; |
| return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFu)))); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I8_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int32_t> v) { |
| const Full512<uint32_t> du32; |
| const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)}; |
| const Vec512<int8_t> i8{_mm512_packs_epi16(i16.raw, i16.raw)}; |
| |
| const VFromD<decltype(du32)> idx32 = Dup128VecFromValues(du32, 0, 4, 8, 12); |
| const Vec512<int8_t> fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)}; |
| return LowerHalf(LowerHalf(fixed)); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I8_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int16_t> v) { |
| const Full512<uint64_t> du64; |
| const Vec512<int8_t> i8{_mm512_packs_epi16(v.raw, v.raw)}; |
| |
| // Compress even u64 lanes into 256 bit. |
| alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; |
| const auto idx64 = Load(du64, kLanes); |
| const Vec512<int8_t> even{_mm512_permutexvar_epi64(idx64.raw, i8.raw)}; |
| return LowerHalf(even); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) { |
| return VFromD<D>{_mm512_cvtsepi64_epi32(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I16_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) { |
| return VFromD<D>{_mm512_cvtsepi64_epi16(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_I8_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) { |
| return VFromD<D>{_mm512_cvtsepi64_epi8(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) { |
| const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; |
| return VFromD<D>{_mm512_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) { |
| const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; |
| return VFromD<D>{_mm512_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) { |
| const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; |
| return VFromD<D>{_mm512_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<uint64_t> v) { |
| return VFromD<D>{_mm512_cvtusepi64_epi32(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<uint64_t> v) { |
| return VFromD<D>{_mm512_cvtusepi64_epi16(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<uint64_t> v) { |
| return VFromD<D>{_mm512_cvtusepi64_epi8(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)> |
| HWY_API VFromD<D> DemoteTo(D df16, Vec512<float> v) { |
| // Work around warnings in the intrinsic definitions (passing -1 as a mask). |
| HWY_DIAGNOSTICS(push) |
| HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") |
| const RebindToUnsigned<decltype(df16)> du16; |
| return BitCast( |
| df16, VFromD<decltype(du16)>{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); |
| HWY_DIAGNOSTICS(pop) |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F16_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /*df16*/, Vec512<double> v) { |
| return VFromD<D>{_mm512_cvtpd_ph(v.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| #if HWY_AVX3_HAVE_F32_TO_BF16C |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_BF16_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /*dbf16*/, Vec512<float> v) { |
| #if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 |
| // Inline assembly workaround for LLVM codegen bug |
| __m256i raw_result; |
| __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); |
| return VFromD<D>{raw_result}; |
| #else |
| // The _mm512_cvtneps_pbh intrinsic returns a __m256bh vector that needs to |
| // be bit-cast to a __m256i vector |
| return VFromD<D>{detail::BitCastToInteger(_mm512_cvtneps_pbh(v.raw))}; |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)> |
| HWY_API VFromD<D> ReorderDemote2To(D /*dbf16*/, Vec512<float> a, |
| Vec512<float> b) { |
| #if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 |
| // Inline assembly workaround for LLVM codegen bug |
| __m512i raw_result; |
| __asm__("vcvtne2ps2bf16 %2, %1, %0" |
| : "=v"(raw_result) |
| : "v"(b.raw), "v"(a.raw)); |
| return VFromD<D>{raw_result}; |
| #else |
| // The _mm512_cvtne2ps_pbh intrinsic returns a __m512bh vector that needs to |
| // be bit-cast to a __m512i vector |
| return VFromD<D>{detail::BitCastToInteger(_mm512_cvtne2ps_pbh(b.raw, a.raw))}; |
| #endif |
| } |
| #endif // HWY_AVX3_HAVE_F32_TO_BF16C |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I16_D(D)> |
| HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec512<int32_t> a, |
| Vec512<int32_t> b) { |
| return VFromD<D>{_mm512_packs_epi32(a.raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U16_D(D)> |
| HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec512<int32_t> a, |
| Vec512<int32_t> b) { |
| return VFromD<D>{_mm512_packus_epi32(a.raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U16_D(D)> |
| HWY_API VFromD<D> ReorderDemote2To(D dn, Vec512<uint32_t> a, |
| Vec512<uint32_t> b) { |
| const DFromV<decltype(a)> du32; |
| const RebindToSigned<decltype(du32)> di32; |
| const auto max_i32 = Set(du32, 0x7FFFFFFFu); |
| |
| return ReorderDemote2To(dn, BitCast(di32, Min(a, max_i32)), |
| BitCast(di32, Min(b, max_i32))); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I8_D(D)> |
| HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec512<int16_t> a, |
| Vec512<int16_t> b) { |
| return VFromD<D>{_mm512_packs_epi16(a.raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec512<int16_t> a, |
| Vec512<int16_t> b) { |
| return VFromD<D>{_mm512_packus_epi16(a.raw, b.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> ReorderDemote2To(D dn, Vec512<uint16_t> a, |
| Vec512<uint16_t> b) { |
| const DFromV<decltype(a)> du16; |
| const RebindToSigned<decltype(du16)> di16; |
| const auto max_i16 = Set(du16, 0x7FFFu); |
| |
| return ReorderDemote2To(dn, BitCast(di16, Min(a, max_i16)), |
| BitCast(di16, Min(b, max_i16))); |
| } |
| |
| template <class D, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD<D>), |
| HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), |
| HWY_IF_T_SIZE_V(V, sizeof(TFromD<D>) * 2), |
| HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV<V>) * 2), |
| HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))> |
| HWY_API VFromD<D> OrderedDemote2To(D d, V a, V b) { |
| const Full512<uint64_t> du64; |
| alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; |
| return BitCast(d, TableLookupLanes(BitCast(du64, ReorderDemote2To(d, a, b)), |
| SetTableIndices(du64, kIdx))); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<double> v) { |
| return VFromD<D>{_mm512_cvtpd_ps(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)> |
| HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec512<double> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvttpd_epi32 with GCC if any |
| // values of v[i] are not within the range of an int32_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
| if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) { |
| typedef double GccF64RawVectType __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); |
| return VFromD<D>{_mm256_setr_epi32( |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[0]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[1]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[2]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[3]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[4]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[5]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[6]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[7]))}; |
| } |
| #endif |
| |
| __m256i raw_result; |
| __asm__("vcvttpd2dq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<D>{raw_result}; |
| #else |
| return VFromD<D>{_mm512_cvttpd_epi32(v.raw)}; |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)> |
| HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec512<double> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvttpd_epu32 with GCC if any |
| // values of v[i] are not within the range of a uint32_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
| if (detail::IsConstantX86VecForF2IConv<uint32_t>(v)) { |
| typedef double GccF64RawVectType __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); |
| return VFromD<D>{_mm256_setr_epi32( |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[0])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[1])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[2])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[3])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[4])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[5])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[6])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[7])))}; |
| } |
| #endif |
| |
| __m256i raw_result; |
| __asm__("vcvttpd2udq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<D>{raw_result}; |
| #else |
| return VFromD<D>{_mm512_cvttpd_epu32(v.raw)}; |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<int64_t, D>> v) { |
| return VFromD<D>{_mm512_cvtepi64_ps(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<uint64_t, D>> v) { |
| return VFromD<D>{_mm512_cvtepu64_ps(v.raw)}; |
| } |
| |
| // For already range-limited input [0, 255]. |
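// Usage sketch (illustrative only): clamp first if the input may exceed 255,
// e.g. U8FromU32(Min(v, Set(Full512<uint32_t>(), 255u))).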
| HWY_API Vec128<uint8_t> U8FromU32(const Vec512<uint32_t> v) { |
| const DFromV<decltype(v)> d32; |
// In each 128-bit block, gather the lower byte of 4 uint32_t lanes into the
| // lowest 4 bytes. |
| const VFromD<decltype(d32)> v8From32 = |
| Dup128VecFromValues(d32, 0x0C080400u, ~0u, ~0u, ~0u); |
| const auto quads = TableLookupBytes(v, v8From32); |
| // Gather the lowest 4 bytes of 4 128-bit blocks. |
| const VFromD<decltype(d32)> index32 = Dup128VecFromValues(d32, 0, 4, 8, 12); |
| const Vec512<uint8_t> bytes{_mm512_permutexvar_epi32(index32.raw, quads.raw)}; |
| return LowerHalf(LowerHalf(bytes)); |
| } |
| |
| // ------------------------------ Truncations |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> TruncateTo(D d, const Vec512<uint64_t> v) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| (void)d; |
| const Full512<uint8_t> d8; |
| const VFromD<decltype(d8)> v8From64 = Dup128VecFromValues( |
| d8, 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56); |
| const Vec512<uint8_t> bytes{_mm512_permutexvar_epi8(v8From64.raw, v.raw)}; |
| return LowerHalf(LowerHalf(LowerHalf(bytes))); |
| #else |
| const Full512<uint32_t> d32; |
| alignas(64) static constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, |
| 0, 2, 4, 6, 8, 10, 12, 14}; |
| const Vec512<uint32_t> even{ |
| _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; |
| return TruncateTo(d, LowerHalf(even)); |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)> |
| HWY_API VFromD<D> TruncateTo(D /* tag */, const Vec512<uint64_t> v) { |
| const Full512<uint16_t> d16; |
| alignas(16) static constexpr uint16_t k16From64[8] = {0, 4, 8, 12, |
| 16, 20, 24, 28}; |
| const Vec512<uint16_t> bytes{ |
| _mm512_permutexvar_epi16(LoadDup128(d16, k16From64).raw, v.raw)}; |
| return LowerHalf(LowerHalf(bytes)); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)> |
| HWY_API VFromD<D> TruncateTo(D /* tag */, const Vec512<uint64_t> v) { |
| const Full512<uint32_t> d32; |
| alignas(64) static constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, |
| 0, 2, 4, 6, 8, 10, 12, 14}; |
| const Vec512<uint32_t> even{ |
| _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; |
| return LowerHalf(even); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> TruncateTo(D /* tag */, const Vec512<uint32_t> v) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| const Full512<uint8_t> d8; |
| const VFromD<decltype(d8)> v8From32 = Dup128VecFromValues( |
| d8, 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60); |
| const Vec512<uint8_t> bytes{_mm512_permutexvar_epi8(v8From32.raw, v.raw)}; |
| #else |
| const Full512<uint32_t> d32; |
// In each 128-bit block, gather the lower byte of 4 uint32_t lanes into the
| // lowest 4 bytes. |
| const VFromD<decltype(d32)> v8From32 = |
| Dup128VecFromValues(d32, 0x0C080400u, ~0u, ~0u, ~0u); |
| const auto quads = TableLookupBytes(v, v8From32); |
| // Gather the lowest 4 bytes of 4 128-bit blocks. |
| const VFromD<decltype(d32)> index32 = Dup128VecFromValues(d32, 0, 4, 8, 12); |
| const Vec512<uint8_t> bytes{_mm512_permutexvar_epi32(index32.raw, quads.raw)}; |
| #endif |
| return LowerHalf(LowerHalf(bytes)); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)> |
| HWY_API VFromD<D> TruncateTo(D /* tag */, const Vec512<uint32_t> v) { |
| const Full512<uint16_t> d16; |
| alignas(64) static constexpr uint16_t k16From32[32] = { |
| 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, |
| 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; |
| const Vec512<uint16_t> bytes{ |
| _mm512_permutexvar_epi16(Load(d16, k16From32).raw, v.raw)}; |
| return LowerHalf(bytes); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)> |
| HWY_API VFromD<D> TruncateTo(D /* tag */, const Vec512<uint16_t> v) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| const Full512<uint8_t> d8; |
| alignas(64) static constexpr uint8_t k8From16[64] = { |
| 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, |
| 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, |
| 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, |
| 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; |
| const Vec512<uint8_t> bytes{ |
| _mm512_permutexvar_epi8(Load(d8, k8From16).raw, v.raw)}; |
| #else |
| const Full512<uint32_t> d32; |
| const VFromD<decltype(d32)> v16From32 = Dup128VecFromValues( |
| d32, 0x06040200u, 0x0E0C0A08u, 0x06040200u, 0x0E0C0A08u); |
| const auto quads = TableLookupBytes(v, v16From32); |
| alignas(64) static constexpr uint32_t kIndex32[16] = { |
| 0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13}; |
| const Vec512<uint8_t> bytes{ |
| _mm512_permutexvar_epi32(Load(d32, kIndex32).raw, quads.raw)}; |
| #endif |
| return LowerHalf(bytes); |
| } |
| |
| // ------------------------------ Convert integer <=> floating point |
| |
| #if HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_API VFromD<D> ConvertTo(D /* tag */, Vec512<uint16_t> v) { |
| return VFromD<D>{_mm512_cvtepu16_ph(v.raw)}; |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> |
| HWY_API VFromD<D> ConvertTo(D /* tag */, Vec512<int16_t> v) { |
| return VFromD<D>{_mm512_cvtepi16_ph(v.raw)}; |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> ConvertTo(D /* tag */, Vec512<int32_t> v) { |
| return VFromD<D>{_mm512_cvtepi32_ps(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> ConvertTo(D /* tag */, Vec512<int64_t> v) { |
| return VFromD<D>{_mm512_cvtepi64_pd(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> |
| HWY_API VFromD<D> ConvertTo(D /* tag*/, Vec512<uint32_t> v) { |
| return VFromD<D>{_mm512_cvtepu32_ps(v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> |
| HWY_API VFromD<D> ConvertTo(D /* tag*/, Vec512<uint64_t> v) { |
| return VFromD<D>{_mm512_cvtepu64_pd(v.raw)}; |
| } |
| |
| // Truncates (rounds toward zero). |
| #if HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I16_D(D)> |
| HWY_API VFromD<D> ConvertInRangeTo(D /*d*/, Vec512<float16_t> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvttph_epi16 with GCC if any |
| // values of v[i] are not within the range of an int16_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ |
| HWY_HAVE_SCALAR_F16_TYPE |
| if (detail::IsConstantX86VecForF2IConv<int16_t>(v)) { |
| typedef hwy::float16_t::Native GccF16RawVectType |
| __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF16RawVectType>(v.raw); |
| return VFromD<D>{ |
| _mm512_set_epi16(detail::X86ConvertScalarFromFloat<int16_t>(raw_v[31]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[30]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[29]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[28]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[27]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[26]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[25]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[24]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[23]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[22]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[21]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[20]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[19]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[18]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[17]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[16]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[15]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[14]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[13]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[12]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[11]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[10]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[9]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[8]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[7]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[6]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[5]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[4]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[3]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[2]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[1]), |
| detail::X86ConvertScalarFromFloat<int16_t>(raw_v[0]))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvttph2w {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<D>{raw_result}; |
| #else |
| return VFromD<D>{_mm512_cvttph_epi16(v.raw)}; |
| #endif |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U16_D(D)> |
| HWY_API VFromD<D> ConvertInRangeTo(D /* tag */, VFromD<RebindToFloat<D>> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvttph_epu16 with GCC if any |
// values of v[i] are not within the range of a uint16_t
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ |
| HWY_HAVE_SCALAR_F16_TYPE |
| if (detail::IsConstantX86VecForF2IConv<uint16_t>(v)) { |
| typedef hwy::float16_t::Native GccF16RawVectType |
| __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF16RawVectType>(v.raw); |
| return VFromD<D>{_mm512_set_epi16( |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[31])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[30])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[29])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[28])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[27])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[26])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[25])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[24])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[23])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[22])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[21])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[20])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[19])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[18])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[17])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[16])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[15])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[14])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[13])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[12])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[11])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[10])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[9])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[8])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[7])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[6])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[5])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[4])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[3])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[2])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[1])), |
| static_cast<int16_t>( |
| detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[0])))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvttph2uw {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<D>{raw_result}; |
| #else |
| return VFromD<D>{_mm512_cvttph_epu16(v.raw)}; |
| #endif |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)> |
| HWY_API VFromD<D> ConvertInRangeTo(D /*d*/, Vec512<float> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvttps_epi32 with GCC if any |
| // values of v[i] are not within the range of an int32_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
| if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) { |
| typedef float GccF32RawVectType __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); |
| return VFromD<D>{_mm512_setr_epi32( |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[0]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[1]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[2]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[3]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[4]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[5]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[6]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[7]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[8]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[9]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[10]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[11]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[12]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[13]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[14]), |
| detail::X86ConvertScalarFromFloat<int32_t>(raw_v[15]))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvttps2dq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<D>{raw_result}; |
| #else |
| return VFromD<D>{_mm512_cvttps_epi32(v.raw)}; |
| #endif |
| } |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)> |
| HWY_API VFromD<D> ConvertInRangeTo(D /*di*/, Vec512<double> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvttpd_epi64 with GCC if any |
| // values of v[i] are not within the range of an int64_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
| if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) { |
| typedef double GccF64RawVectType __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); |
| return VFromD<D>{_mm512_setr_epi64( |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[0]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[1]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[2]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[3]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[4]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[5]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[6]), |
| detail::X86ConvertScalarFromFloat<int64_t>(raw_v[7]))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvttpd2qq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<D>{raw_result}; |
| #else |
| return VFromD<D>{_mm512_cvttpd_epi64(v.raw)}; |
| #endif |
| } |
| template <class DU, HWY_IF_V_SIZE_D(DU, 64), HWY_IF_U32_D(DU)> |
| HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvttps_epu32 with GCC if any |
// values of v[i] are not within the range of a uint32_t
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
| if (detail::IsConstantX86VecForF2IConv<uint32_t>(v)) { |
| typedef float GccF32RawVectType __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); |
| return VFromD<DU>{_mm512_setr_epi32( |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[0])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[1])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[2])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[3])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[4])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[5])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[6])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[7])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[8])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[9])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[10])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[11])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[12])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[13])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[14])), |
| static_cast<int32_t>( |
| detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[15])))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvttps2udq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<DU>{raw_result}; |
| #else |
| return VFromD<DU>{_mm512_cvttps_epu32(v.raw)}; |
| #endif |
| } |
| template <class DU, HWY_IF_V_SIZE_D(DU, 64), HWY_IF_U64_D(DU)> |
| HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvttpd_epu64 with GCC if any |
// values of v[i] are not within the range of a uint64_t
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
if (detail::IsConstantX86VecForF2IConv<uint64_t>(v)) {
| typedef double GccF64RawVectType __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); |
| return VFromD<DU>{_mm512_setr_epi64( |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[0])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[1])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[2])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[3])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[4])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[5])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[6])), |
| static_cast<int64_t>( |
| detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[7])))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvttpd2uqq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<DU>{raw_result}; |
| #else |
| return VFromD<DU>{_mm512_cvttpd_epu64(v.raw)}; |
| #endif |
| } |
| |
| template <class DI, HWY_IF_V_SIZE_D(DI, 64), HWY_IF_I32_D(DI)> |
| static HWY_INLINE VFromD<DI> NearestIntInRange(DI, |
| VFromD<RebindToFloat<DI>> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvtps_epi32 with GCC if any |
| // values of v[i] are not within the range of an int32_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
| if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) { |
| typedef float GccF32RawVectType __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); |
| return VFromD<DI>{ |
| _mm512_setr_epi32(detail::X86ScalarNearestInt<int32_t>(raw_v[0]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[1]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[2]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[3]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[4]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[5]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[6]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[7]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[8]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[9]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[10]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[11]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[12]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[13]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[14]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[15]))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvtps2dq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<DI>{raw_result}; |
| #else |
| return VFromD<DI>{_mm512_cvtps_epi32(v.raw)}; |
| #endif |
| } |
| |
| #if HWY_HAVE_FLOAT16 |
| template <class DI, HWY_IF_V_SIZE_D(DI, 64), HWY_IF_I16_D(DI)> |
| static HWY_INLINE VFromD<DI> NearestIntInRange(DI /*d*/, Vec512<float16_t> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvtph_epi16 with GCC if any |
| // values of v[i] are not within the range of an int16_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ |
| HWY_HAVE_SCALAR_F16_TYPE |
| if (detail::IsConstantX86VecForF2IConv<int16_t>(v)) { |
| typedef hwy::float16_t::Native GccF16RawVectType |
| __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF16RawVectType>(v.raw); |
| return VFromD<DI>{ |
| _mm512_set_epi16(detail::X86ScalarNearestInt<int16_t>(raw_v[31]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[30]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[29]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[28]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[27]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[26]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[25]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[24]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[23]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[22]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[21]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[20]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[19]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[18]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[17]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[16]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[15]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[14]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[13]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[12]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[11]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[10]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[9]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[8]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[7]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[6]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[5]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[4]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[3]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[2]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[1]), |
| detail::X86ScalarNearestInt<int16_t>(raw_v[0]))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvtph2w {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<DI>{raw_result}; |
| #else |
| return VFromD<DI>{_mm512_cvtph_epi16(v.raw)}; |
| #endif |
| } |
| #endif // HWY_HAVE_FLOAT16 |
| |
| template <class DI, HWY_IF_V_SIZE_D(DI, 64), HWY_IF_I64_D(DI)> |
| static HWY_INLINE VFromD<DI> NearestIntInRange(DI /*di*/, Vec512<double> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvtpd_epi64 with GCC if any |
| // values of v[i] are not within the range of an int64_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
| if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) { |
| typedef double GccF64RawVectType __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); |
| return VFromD<DI>{ |
| _mm512_setr_epi64(detail::X86ScalarNearestInt<int64_t>(raw_v[0]), |
| detail::X86ScalarNearestInt<int64_t>(raw_v[1]), |
| detail::X86ScalarNearestInt<int64_t>(raw_v[2]), |
| detail::X86ScalarNearestInt<int64_t>(raw_v[3]), |
| detail::X86ScalarNearestInt<int64_t>(raw_v[4]), |
| detail::X86ScalarNearestInt<int64_t>(raw_v[5]), |
| detail::X86ScalarNearestInt<int64_t>(raw_v[6]), |
| detail::X86ScalarNearestInt<int64_t>(raw_v[7]))}; |
| } |
| #endif |
| |
| __m512i raw_result; |
| __asm__("vcvtpd2qq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<DI>{raw_result}; |
| #else |
| return VFromD<DI>{_mm512_cvtpd_epi64(v.raw)}; |
| #endif |
| } |
| |
| template <class DI, HWY_IF_V_SIZE_D(DI, 32), HWY_IF_I32_D(DI)> |
| static HWY_INLINE VFromD<DI> DemoteToNearestIntInRange(DI /* tag */, |
| Vec512<double> v) { |
| #if HWY_COMPILER_GCC_ACTUAL |
| // Workaround for undefined behavior in _mm512_cvtpd_epi32 with GCC if any |
| // values of v[i] are not within the range of an int32_t |
| |
| #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD |
| if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) { |
| typedef double GccF64RawVectType __attribute__((__vector_size__(64))); |
| const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); |
| return VFromD<DI>{ |
| _mm256_setr_epi32(detail::X86ScalarNearestInt<int32_t>(raw_v[0]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[1]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[2]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[3]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[4]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[5]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[6]), |
| detail::X86ScalarNearestInt<int32_t>(raw_v[7]))}; |
| } |
| #endif |
| |
| __m256i raw_result; |
| __asm__("vcvtpd2dq {%1, %0|%0, %1}" |
| : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) |
| : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) |
| :); |
| return VFromD<DI>{raw_result}; |
| #else |
| return VFromD<DI>{_mm512_cvtpd_epi32(v.raw)}; |
| #endif |
| } |
| |
| // ================================================== CRYPTO |
| |
| #if !defined(HWY_DISABLE_PCLMUL_AES) |
| |
| HWY_API Vec512<uint8_t> AESRound(Vec512<uint8_t> state, |
| Vec512<uint8_t> round_key) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return Vec512<uint8_t>{_mm512_aesenc_epi128(state.raw, round_key.raw)}; |
| #else |
| const DFromV<decltype(state)> d; |
| const Half<decltype(d)> d2; |
| return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), |
| AESRound(LowerHalf(state), LowerHalf(round_key))); |
| #endif |
| } |
| |
| HWY_API Vec512<uint8_t> AESLastRound(Vec512<uint8_t> state, |
| Vec512<uint8_t> round_key) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return Vec512<uint8_t>{_mm512_aesenclast_epi128(state.raw, round_key.raw)}; |
| #else |
| const DFromV<decltype(state)> d; |
| const Half<decltype(d)> d2; |
| return Combine(d, |
| AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), |
| AESLastRound(LowerHalf(state), LowerHalf(round_key))); |
| #endif |
| } |
| |
| HWY_API Vec512<uint8_t> AESRoundInv(Vec512<uint8_t> state, |
| Vec512<uint8_t> round_key) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return Vec512<uint8_t>{_mm512_aesdec_epi128(state.raw, round_key.raw)}; |
| #else |
| const Full512<uint8_t> d; |
| const Half<decltype(d)> d2; |
| return Combine(d, AESRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)), |
| AESRoundInv(LowerHalf(state), LowerHalf(round_key))); |
| #endif |
| } |
| |
| HWY_API Vec512<uint8_t> AESLastRoundInv(Vec512<uint8_t> state, |
| Vec512<uint8_t> round_key) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return Vec512<uint8_t>{_mm512_aesdeclast_epi128(state.raw, round_key.raw)}; |
| #else |
| const Full512<uint8_t> d; |
| const Half<decltype(d)> d2; |
| return Combine( |
| d, AESLastRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)), |
| AESLastRoundInv(LowerHalf(state), LowerHalf(round_key))); |
| #endif |
| } |
| |
| template <uint8_t kRcon> |
| HWY_API Vec512<uint8_t> AESKeyGenAssist(Vec512<uint8_t> v) { |
| const Full512<uint8_t> d; |
| #if HWY_TARGET <= HWY_AVX3_DL |
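// There is no 512-bit AESKEYGENASSIST instruction. Roughly, the emulation
// below duplicates words 1 and 3 (DupOdd), applies SubBytes via AESLastRound
// (whose ShiftRows is undone by the final byte shuffle), XORs in kRcon via
// the round key, and performs RotWord plus lane placement via rotWordShuffle.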
| const VFromD<decltype(d)> rconXorMask = Dup128VecFromValues( |
| d, 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0); |
| const VFromD<decltype(d)> rotWordShuffle = Dup128VecFromValues( |
| d, 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12); |
| const Repartition<uint32_t, decltype(d)> du32; |
| const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); |
| const auto sub_word_result = AESLastRound(w13, rconXorMask); |
| return TableLookupBytes(sub_word_result, rotWordShuffle); |
| #else |
| const Half<decltype(d)> d2; |
| return Combine(d, AESKeyGenAssist<kRcon>(UpperHalf(d2, v)), |
| AESKeyGenAssist<kRcon>(LowerHalf(v))); |
| #endif |
| } |
| |
| HWY_API Vec512<uint64_t> CLMulLower(Vec512<uint64_t> va, Vec512<uint64_t> vb) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return Vec512<uint64_t>{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x00)}; |
| #else |
| alignas(64) uint64_t a[8]; |
| alignas(64) uint64_t b[8]; |
| const DFromV<decltype(va)> d; |
| const Half<Half<decltype(d)>> d128; |
| Store(va, d, a); |
| Store(vb, d, b); |
| for (size_t i = 0; i < 8; i += 2) { |
| const auto mul = CLMulLower(Load(d128, a + i), Load(d128, b + i)); |
| Store(mul, d128, a + i); |
| } |
| return Load(d, a); |
| #endif |
| } |
| |
| HWY_API Vec512<uint64_t> CLMulUpper(Vec512<uint64_t> va, Vec512<uint64_t> vb) { |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return Vec512<uint64_t>{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x11)}; |
| #else |
| alignas(64) uint64_t a[8]; |
| alignas(64) uint64_t b[8]; |
| const DFromV<decltype(va)> d; |
| const Half<Half<decltype(d)>> d128; |
| Store(va, d, a); |
| Store(vb, d, b); |
| for (size_t i = 0; i < 8; i += 2) { |
| const auto mul = CLMulUpper(Load(d128, a + i), Load(d128, b + i)); |
| Store(mul, d128, a + i); |
| } |
| return Load(d, a); |
| #endif |
| } |
| |
| #endif // HWY_DISABLE_PCLMUL_AES |
| |
| // ================================================== MISC |
| |
| // ------------------------------ SumsOfAdjQuadAbsDiff (Broadcast, |
// SumsOfShuffledQuadAbsDiff)
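// Per-lane semantics (matching MPSADBW within each 128-bit block): u16 result
// lane i is the sum over j in [0, 4) of
// |a[4 * kAOffset + i + j] - b[4 * kBOffset + j]|.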
| |
| template <int kAOffset, int kBOffset> |
| static Vec512<uint16_t> SumsOfAdjQuadAbsDiff(Vec512<uint8_t> a, |
| Vec512<uint8_t> b) { |
| static_assert(0 <= kAOffset && kAOffset <= 1, |
| "kAOffset must be between 0 and 1"); |
| static_assert(0 <= kBOffset && kBOffset <= 3, |
| "kBOffset must be between 0 and 3"); |
| |
| const DFromV<decltype(a)> d; |
| const RepartitionToWideX2<decltype(d)> du32; |
| |
| // While AVX3 does not have a _mm512_mpsadbw_epu8 intrinsic, the |
| // SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on |
| // AVX3 using SumsOfShuffledQuadAbsDiff and U32 Broadcast. |
| return SumsOfShuffledQuadAbsDiff<kAOffset + 2, kAOffset + 1, kAOffset + 1, |
| kAOffset>( |
| a, BitCast(d, Broadcast<kBOffset>(BitCast(du32, b)))); |
| } |
| |
| #if !HWY_IS_MSAN |
| // ------------------------------ I32/I64 SaturatedAdd (MaskFromVec) |
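// The ternary-logic constant 0x42 computes, per bit, "a equals b but differs
// from sum"; applied to the sign bits this detects signed overflow. 0x55 is
// ~C, so the masked ternlog below produces LimitsMin (= ~LimitsMax) in lanes
// where a is negative and LimitsMax otherwise, i.e. the saturated value.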
| |
| HWY_API Vec512<int32_t> SaturatedAdd(Vec512<int32_t> a, Vec512<int32_t> b) { |
| const DFromV<decltype(a)> d; |
| const auto sum = a + b; |
| const auto overflow_mask = MaskFromVec( |
| Vec512<int32_t>{_mm512_ternarylogic_epi32(a.raw, b.raw, sum.raw, 0x42)}); |
| const auto i32_max = Set(d, LimitsMax<int32_t>()); |
| const Vec512<int32_t> overflow_result{_mm512_mask_ternarylogic_epi32( |
| i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)}; |
| return IfThenElse(overflow_mask, overflow_result, sum); |
| } |
| |
| HWY_API Vec512<int64_t> SaturatedAdd(Vec512<int64_t> a, Vec512<int64_t> b) { |
| const DFromV<decltype(a)> d; |
| const auto sum = a + b; |
| const auto overflow_mask = MaskFromVec( |
| Vec512<int64_t>{_mm512_ternarylogic_epi64(a.raw, b.raw, sum.raw, 0x42)}); |
| const auto i64_max = Set(d, LimitsMax<int64_t>()); |
| const Vec512<int64_t> overflow_result{_mm512_mask_ternarylogic_epi64( |
| i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; |
| return IfThenElse(overflow_mask, overflow_result, sum); |
| } |
| |
| // ------------------------------ I32/I64 SaturatedSub (MaskFromVec) |
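// Here the ternary-logic constant 0x18 computes, per bit, "a differs from
// both b and diff"; on the sign bits this marks signed overflow of a - b.
// The saturated value is again selected by the sign of a, as in SaturatedAdd.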
| |
| HWY_API Vec512<int32_t> SaturatedSub(Vec512<int32_t> a, Vec512<int32_t> b) { |
| const DFromV<decltype(a)> d; |
| const auto diff = a - b; |
| const auto overflow_mask = MaskFromVec( |
| Vec512<int32_t>{_mm512_ternarylogic_epi32(a.raw, b.raw, diff.raw, 0x18)}); |
| const auto i32_max = Set(d, LimitsMax<int32_t>()); |
| const Vec512<int32_t> overflow_result{_mm512_mask_ternarylogic_epi32( |
| i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)}; |
| return IfThenElse(overflow_mask, overflow_result, diff); |
| } |
| |
| HWY_API Vec512<int64_t> SaturatedSub(Vec512<int64_t> a, Vec512<int64_t> b) { |
| const DFromV<decltype(a)> d; |
| const auto diff = a - b; |
| const auto overflow_mask = MaskFromVec( |
| Vec512<int64_t>{_mm512_ternarylogic_epi64(a.raw, b.raw, diff.raw, 0x18)}); |
| const auto i64_max = Set(d, LimitsMax<int64_t>()); |
| const Vec512<int64_t> overflow_result{_mm512_mask_ternarylogic_epi64( |
| i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; |
| return IfThenElse(overflow_mask, overflow_result, diff); |
| } |
| #endif // !HWY_IS_MSAN |
| |
| // ------------------------------ Mask testing |
| |
| // Beware: the suffix indicates the number of mask bits, not lane size! |
| |
| namespace detail { |
| |
| template <typename T> |
| HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512<T> mask) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return _kortestz_mask64_u8(mask.raw, mask.raw); |
| #else |
| return mask.raw == 0; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512<T> mask) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return _kortestz_mask32_u8(mask.raw, mask.raw); |
| #else |
| return mask.raw == 0; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512<T> mask) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return _kortestz_mask16_u8(mask.raw, mask.raw); |
| #else |
| return mask.raw == 0; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512<T> mask) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return _kortestz_mask8_u8(mask.raw, mask.raw); |
| #else |
| return mask.raw == 0; |
| #endif |
| } |
| |
| } // namespace detail |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API bool AllFalse(D /* tag */, const MFromD<D> mask) { |
| return detail::AllFalse(hwy::SizeTag<sizeof(TFromD<D>)>(), mask); |
| } |
| |
| namespace detail { |
| |
| template <typename T> |
| HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512<T> mask) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return _kortestc_mask64_u8(mask.raw, mask.raw); |
| #else |
| return mask.raw == 0xFFFFFFFFFFFFFFFFull; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512<T> mask) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return _kortestc_mask32_u8(mask.raw, mask.raw); |
| #else |
| return mask.raw == 0xFFFFFFFFull; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512<T> mask) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return _kortestc_mask16_u8(mask.raw, mask.raw); |
| #else |
| return mask.raw == 0xFFFFull; |
| #endif |
| } |
| template <typename T> |
| HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512<T> mask) { |
| #if HWY_COMPILER_HAS_MASK_INTRINSICS |
| return _kortestc_mask8_u8(mask.raw, mask.raw); |
| #else |
| return mask.raw == 0xFFull; |
| #endif |
| } |
| |
| } // namespace detail |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API bool AllTrue(D /* tag */, const MFromD<D> mask) { |
| return detail::AllTrue(hwy::SizeTag<sizeof(TFromD<D>)>(), mask); |
| } |
| |
// `bits` points to at least 8 readable bytes, not all of which need be valid.
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API MFromD<D> LoadMaskBits(D /* tag */, const uint8_t* HWY_RESTRICT bits) { |
| MFromD<D> mask; |
| CopyBytes<8 / sizeof(TFromD<D>)>(bits, &mask.raw); |
| // N >= 8 (= 512 / 64), so no need to mask invalid bits. |
| return mask; |
| } |
| |
// `bits` points to at least 8 writable bytes.
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API size_t StoreMaskBits(D /* tag */, MFromD<D> mask, uint8_t* bits) { |
| const size_t kNumBytes = 8 / sizeof(TFromD<D>); |
| CopyBytes<kNumBytes>(&mask.raw, bits); |
| // N >= 8 (= 512 / 64), so no need to mask invalid bits. |
| return kNumBytes; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API size_t CountTrue(D /* tag */, const MFromD<D> mask) { |
| return PopCount(static_cast<uint64_t>(mask.raw)); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_T_SIZE_D(D, 1)> |
| HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD<D> mask) { |
| return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD<D> mask) { |
| return Num0BitsBelowLS1Bit_Nonzero64(mask.raw); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API intptr_t FindFirstTrue(D d, MFromD<D> mask) { |
| return mask.raw ? static_cast<intptr_t>(FindKnownFirstTrue(d, mask)) |
| : intptr_t{-1}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_T_SIZE_D(D, 1)> |
| HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD<D> mask) { |
| return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask.raw); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD<D> mask) { |
| return 63 - Num0BitsAboveMS1Bit_Nonzero64(mask.raw); |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API intptr_t FindLastTrue(D d, MFromD<D> mask) { |
| return mask.raw ? static_cast<intptr_t>(FindKnownLastTrue(d, mask)) |
| : intptr_t{-1}; |
| } |
| |
| // ------------------------------ Compress |
| |
| template <typename T, HWY_IF_T_SIZE(T, 8)> |
| HWY_API Vec512<T> Compress(Vec512<T> v, Mask512<T> mask) { |
| // See CompressIsPartition. u64 is faster than u32. |
| alignas(16) static constexpr uint64_t packed_array[256] = { |
| // From PrintCompress32x8Tables, without the FirstN extension (there is |
| // no benefit to including them because 64-bit CompressStore is anyway |
| // masked, but also no harm because TableLookupLanes ignores the MSB). |
| 0x76543210, 0x76543210, 0x76543201, 0x76543210, 0x76543102, 0x76543120, |
| 0x76543021, 0x76543210, 0x76542103, 0x76542130, 0x76542031, 0x76542310, |
| 0x76541032, 0x76541320, 0x76540321, 0x76543210, 0x76532104, 0x76532140, |
| 0x76532041, 0x76532410, 0x76531042, 0x76531420, 0x76530421, 0x76534210, |
| 0x76521043, 0x76521430, 0x76520431, 0x76524310, 0x76510432, 0x76514320, |
| 0x76504321, 0x76543210, 0x76432105, 0x76432150, 0x76432051, 0x76432510, |
| 0x76431052, 0x76431520, 0x76430521, 0x76435210, 0x76421053, 0x76421530, |
| 0x76420531, 0x76425310, 0x76410532, 0x76415320, 0x76405321, 0x76453210, |
| 0x76321054, 0x76321540, 0x76320541, 0x76325410, 0x76310542, 0x76315420, |
| 0x76305421, 0x76354210, 0x76210543, 0x76215430, 0x76205431, 0x76254310, |
| 0x76105432, 0x76154320, 0x76054321, 0x76543210, 0x75432106, 0x75432160, |
| 0x75432061, 0x75432610, 0x75431062, 0x75431620, 0x75430621, 0x75436210, |
| 0x75421063, 0x75421630, 0x75420631, 0x75426310, 0x75410632, 0x75416320, |
| 0x75406321, 0x75463210, 0x75321064, 0x75321640, 0x75320641, 0x75326410, |
| 0x75310642, 0x75316420, 0x75306421, 0x75364210, 0x75210643, 0x75216430, |
| 0x75206431, 0x75264310, 0x75106432, 0x75164320, 0x75064321, 0x75643210, |
| 0x74321065, 0x74321650, 0x74320651, 0x74326510, 0x74310652, 0x74316520, |
| 0x74306521, 0x74365210, 0x74210653, 0x74216530, 0x74206531, 0x74265310, |
| 0x74106532, 0x74165320, 0x74065321, 0x74653210, 0x73210654, 0x73216540, |
| 0x73206541, 0x73265410, 0x73106542, 0x73165420, 0x73065421, 0x73654210, |
| 0x72106543, 0x72165430, 0x72065431, 0x72654310, 0x71065432, 0x71654320, |
| 0x70654321, 0x76543210, 0x65432107, 0x65432170, 0x65432071, 0x65432710, |
| 0x65431072, 0x65431720, 0x65430721, 0x65437210, 0x65421073, 0x65421730, |
| 0x65420731, 0x65427310, 0x65410732, 0x65417320, 0x65407321, 0x65473210, |
| 0x65321074, 0x65321740, 0x65320741, 0x65327410, 0x65310742, 0x65317420, |
| 0x65307421, 0x65374210, 0x65210743, 0x65217430, 0x65207431, 0x65274310, |
| 0x65107432, 0x65174320, 0x65074321, 0x65743210, 0x64321075, 0x64321750, |
| 0x64320751, 0x64327510, 0x64310752, 0x64317520, 0x64307521, 0x64375210, |
| 0x64210753, 0x64217530, 0x64207531, 0x64275310, 0x64107532, 0x64175320, |
| 0x64075321, 0x64753210, 0x63210754, 0x63217540, 0x63207541, 0x63275410, |
| 0x63107542, 0x63175420, 0x63075421, 0x63754210, 0x62107543, 0x62175430, |
| 0x62075431, 0x62754310, 0x61075432, 0x61754320, 0x60754321, 0x67543210, |
| 0x54321076, 0x54321760, 0x54320761, 0x54327610, 0x54310762, 0x54317620, |
| 0x54307621, 0x54376210, 0x54210763, 0x54217630, 0x54207631, 0x54276310, |
| 0x54107632, 0x54176320, 0x54076321, 0x54763210, 0x53210764, 0x53217640, |
| 0x53207641, 0x53276410, 0x53107642, 0x53176420, 0x53076421, 0x53764210, |
| 0x52107643, 0x52176430, 0x52076431, 0x52764310, 0x51076432, 0x51764320, |
| 0x50764321, 0x57643210, 0x43210765, 0x43217650, 0x43207651, 0x43276510, |
| 0x43107652, 0x43176520, 0x43076521, 0x43765210, 0x42107653, 0x42176530, |
| 0x42076531, 0x42765310, 0x41076532, 0x41765320, 0x40765321, 0x47653210, |
| 0x32107654, 0x32176540, 0x32076541, 0x32765410, 0x31076542, 0x31765420, |
| 0x30765421, 0x37654210, 0x21076543, 0x21765430, 0x20765431, 0x27654310, |
| 0x10765432, 0x17654320, 0x07654321, 0x76543210}; |
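// Worked example: mask.raw = 5 (lanes 0 and 2 true) selects
// packed_array[5] = 0x76543120, whose nibbles from the LSB are the lane
// indices {0, 2, 1, 3, 4, 5, 6, 7}: the true lanes first, then the rest, as
// required by CompressIsPartition.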
| |
// For lane i, shift the i-th 4-bit index down to bits [0, 4) -
| // _mm512_permutexvar_epi64 will ignore the upper bits. |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du64; |
| const auto packed = Set(du64, packed_array[mask.raw]); |
| alignas(64) static constexpr uint64_t shifts[8] = {0, 4, 8, 12, |
| 16, 20, 24, 28}; |
| const auto indices = Indices512<T>{(packed >> Load(du64, shifts)).raw}; |
| return TableLookupLanes(v, indices); |
| } |
| |
| // ------------------------------ Expand |
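// Expand moves the i-th of the first CountTrue(mask) input lanes into the
// i-th true lane of the result and zeroes the remaining lanes, e.g.
// Expand({a, b, c, ..}, {1, 0, 1, 0, ..}) == {a, 0, b, 0, ..}.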
| |
| namespace detail { |
| |
| #if HWY_TARGET <= HWY_AVX3_DL // VBMI2 |
| HWY_INLINE Vec512<uint8_t> NativeExpand(Vec512<uint8_t> v, |
| Mask512<uint8_t> mask) { |
| return Vec512<uint8_t>{_mm512_maskz_expand_epi8(mask.raw, v.raw)}; |
| } |
| |
| HWY_INLINE Vec512<uint16_t> NativeExpand(Vec512<uint16_t> v, |
| Mask512<uint16_t> mask) { |
| return Vec512<uint16_t>{_mm512_maskz_expand_epi16(mask.raw, v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U8_D(D)> |
| HWY_INLINE VFromD<D> NativeLoadExpand(Mask512<uint8_t> mask, D /* d */, |
| const uint8_t* HWY_RESTRICT unaligned) { |
| return VFromD<D>{_mm512_maskz_expandloadu_epi8(mask.raw, unaligned)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U16_D(D)> |
| HWY_INLINE VFromD<D> NativeLoadExpand(Mask512<uint16_t> mask, D /* d */, |
| const uint16_t* HWY_RESTRICT unaligned) { |
| return VFromD<D>{_mm512_maskz_expandloadu_epi16(mask.raw, unaligned)}; |
| } |
| #endif // HWY_TARGET <= HWY_AVX3_DL |
| |
| HWY_INLINE Vec512<uint32_t> NativeExpand(Vec512<uint32_t> v, |
| Mask512<uint32_t> mask) { |
| return Vec512<uint32_t>{_mm512_maskz_expand_epi32(mask.raw, v.raw)}; |
| } |
| |
| HWY_INLINE Vec512<uint64_t> NativeExpand(Vec512<uint64_t> v, |
| Mask512<uint64_t> mask) { |
| return Vec512<uint64_t>{_mm512_maskz_expand_epi64(mask.raw, v.raw)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U32_D(D)> |
| HWY_INLINE VFromD<D> NativeLoadExpand(Mask512<uint32_t> mask, D /* d */, |
| const uint32_t* HWY_RESTRICT unaligned) { |
| return VFromD<D>{_mm512_maskz_expandloadu_epi32(mask.raw, unaligned)}; |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)> |
| HWY_INLINE VFromD<D> NativeLoadExpand(Mask512<uint64_t> mask, D /* d */, |
| const uint64_t* HWY_RESTRICT unaligned) { |
| return VFromD<D>{_mm512_maskz_expandloadu_epi64(mask.raw, unaligned)}; |
| } |
| |
| } // namespace detail |
| |
| template <typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_API Vec512<T> Expand(Vec512<T> v, const Mask512<T> mask) { |
| const Full512<T> d; |
| #if HWY_TARGET <= HWY_AVX3_DL // VBMI2 |
| const RebindToUnsigned<decltype(d)> du; |
| const auto mu = RebindMask(du, mask); |
| return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); |
| #else |
| // LUTs are infeasible for 2^64 possible masks, so splice together two |
| // half-vector Expand. |
| const Full256<T> dh; |
| constexpr size_t N = MaxLanes(d); |
| // We have to shift the input by a variable number of u8. Shuffling requires |
| // VBMI2, in which case we would already have NativeExpand. We instead |
| // load at an offset, which may incur a store to load forwarding stall. |
| alignas(64) T lanes[N]; |
| Store(v, d, lanes); |
| using Bits = typename Mask256<T>::Raw; |
| const Mask256<T> maskL{ |
| static_cast<Bits>(mask.raw & Bits{(1ULL << (N / 2)) - 1})}; |
| const Mask256<T> maskH{static_cast<Bits>(mask.raw >> (N / 2))}; |
| const size_t countL = CountTrue(dh, maskL); |
| const Vec256<T> expandL = Expand(LowerHalf(v), maskL); |
| const Vec256<T> expandH = Expand(LoadU(dh, lanes + countL), maskH); |
| return Combine(d, expandH, expandL); |
| #endif |
| } |
| |
| template <typename T, HWY_IF_T_SIZE(T, 2)> |
| HWY_API Vec512<T> Expand(Vec512<T> v, const Mask512<T> mask) { |
| const Full512<T> d; |
| const RebindToUnsigned<decltype(d)> du; |
| const Vec512<uint16_t> vu = BitCast(du, v); |
| #if HWY_TARGET <= HWY_AVX3_DL // VBMI2 |
| return BitCast(d, detail::NativeExpand(vu, RebindMask(du, mask))); |
| #else // AVX3 |
| // LUTs are infeasible for 2^32 possible masks, so splice together two |
| // half-vector Expand. |
| const Full256<T> dh; |
| HWY_LANES_CONSTEXPR size_t N = Lanes(d); |
| using Bits = typename Mask256<T>::Raw; |
| const Mask256<T> maskL{ |
| static_cast<Bits>(mask.raw & static_cast<Bits>((1ULL << (N / 2)) - 1))}; |
| const Mask256<T> maskH{static_cast<Bits>(mask.raw >> (N / 2))}; |
| // In AVX3 we can permutevar, which avoids a potential store to load |
| // forwarding stall vs. reloading the input. |
alignas(64) static constexpr uint16_t iota[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
| 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, |
| 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; |
| const Vec512<uint16_t> indices = LoadU(du, iota + CountTrue(dh, maskL)); |
| const Vec512<uint16_t> shifted{_mm512_permutexvar_epi16(indices.raw, vu.raw)}; |
| const Vec256<T> expandL = Expand(LowerHalf(v), maskL); |
| const Vec256<T> expandH = Expand(LowerHalf(BitCast(d, shifted)), maskH); |
| return Combine(d, expandH, expandL); |
| #endif // AVX3 |
| } |
| |
| template <class V, class M, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 4) | (1 << 8))> |
| HWY_API V Expand(V v, const M mask) { |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| const auto mu = RebindMask(du, mask); |
| return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); |
| } |
| |
| // For smaller vectors, it is likely more efficient to promote to 32-bit. |
| // This works for u8x16, u16x8, u16x16 (can be promoted to u32x16), but is |
| // unnecessary if HWY_AVX3_DL, which provides native instructions. |
| #if HWY_TARGET > HWY_AVX3_DL // no VBMI2 |
| |
| template <class V, class M, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2)), |
| HWY_IF_LANES_LE_D(DFromV<V>, 16)> |
| HWY_API V Expand(V v, M mask) { |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du; |
| const Rebind<uint32_t, decltype(d)> du32; |
| const VFromD<decltype(du)> vu = BitCast(du, v); |
| using M32 = MFromD<decltype(du32)>; |
| const M32 m32{static_cast<typename M32::Raw>(mask.raw)}; |
| return BitCast(d, TruncateTo(du, Expand(PromoteTo(du32, vu), m32))); |
| } |
| |
| #endif // HWY_TARGET > HWY_AVX3_DL |
| |
| // ------------------------------ LoadExpand |
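// LoadExpand(mask, d, unaligned) matches Expand(LoadU(d, unaligned), mask);
// the native expand-load used below reads only CountTrue(mask) contiguous
// elements from memory.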
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), |
| HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 1) | (1 << 2))> |
| HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d, |
| const TFromD<D>* HWY_RESTRICT unaligned) { |
| #if HWY_TARGET <= HWY_AVX3_DL // VBMI2 |
| const RebindToUnsigned<decltype(d)> du; |
| using TU = TFromD<decltype(du)>; |
| const TU* HWY_RESTRICT pu = reinterpret_cast<const TU*>(unaligned); |
| const MFromD<decltype(du)> mu = RebindMask(du, mask); |
| return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); |
| #else |
| return Expand(LoadU(d, unaligned), mask); |
| #endif |
| } |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), |
| HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))> |
| HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d, |
| const TFromD<D>* HWY_RESTRICT unaligned) { |
| const RebindToUnsigned<decltype(d)> du; |
| using TU = TFromD<decltype(du)>; |
| const TU* HWY_RESTRICT pu = reinterpret_cast<const TU*>(unaligned); |
| const MFromD<decltype(du)> mu = RebindMask(du, mask); |
| return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); |
| } |
| |
| // ------------------------------ CompressNot |
| |
| template <typename T, HWY_IF_T_SIZE(T, 8)> |
| HWY_API Vec512<T> CompressNot(Vec512<T> v, Mask512<T> mask) { |
| // See CompressIsPartition. u64 is faster than u32. |
| alignas(16) static constexpr uint64_t packed_array[256] = { |
| // From PrintCompressNot32x8Tables, without the FirstN extension (there is |
| // no benefit to including them because 64-bit CompressStore is anyway |
| // masked, but also no harm because TableLookupLanes ignores the MSB). |
| 0x76543210, 0x07654321, 0x17654320, 0x10765432, 0x27654310, 0x20765431, |
| 0x21765430, 0x21076543, 0x37654210, 0x30765421, 0x31765420, 0x31076542, |
| 0x32765410, 0x32076541, 0x32176540, 0x32107654, 0x47653210, 0x40765321, |
| 0x41765320, 0x41076532, 0x42765310, 0x42076531, 0x42176530, 0x42107653, |
| 0x43765210, 0x43076521, 0x43176520, 0x43107652, 0x43276510, 0x43207651, |
| 0x43217650, 0x43210765, 0x57643210, 0x50764321, 0x51764320, 0x51076432, |
| 0x52764310, 0x52076431, 0x52176430, 0x52107643, 0x53764210, 0x53076421, |
| 0x53176420, 0x53107642, 0x53276410, 0x53207641, 0x53217640, 0x53210764, |
| 0x54763210, 0x54076321, 0x54176320, 0x54107632, 0x54276310, 0x54207631, |
| 0x54217630, 0x54210763, 0x54376210, 0x54307621, 0x54317620, 0x54310762, |
| 0x54327610, 0x54320761, 0x54321760, 0x54321076, 0x67543210, 0x60754321, |
| 0x61754320, 0x61075432, 0x62754310, 0x62075431, 0x62175430, 0x62107543, |
| 0x63754210, 0x63075421, 0x63175420, 0x63107542, 0x63275410, 0x63207541, |
| 0x63217540, 0x63210754, 0x64753210, 0x64075321, 0x64175320, 0x64107532, |
| 0x64275310, 0x64207531, 0x64217530, 0x64210753, 0x64375210, 0x64307521, |
| 0x64317520, 0x64310752, 0x64327510, 0x64320751, 0x64321750, 0x64321075, |
| 0x65743210, 0x65074321, 0x65174320, 0x65107432, 0x65274310, 0x65207431, |
| 0x65217430, 0x65210743, 0x65374210, 0x65307421, 0x65317420, 0x65310742, |
| 0x65327410, 0x65320741, 0x65321740, 0x65321074, 0x65473210, 0x65407321, |
| 0x65417320, 0x65410732, 0x65427310, 0x65420731, 0x65421730, 0x65421073, |
| 0x65437210, 0x65430721, 0x65431720, 0x65431072, 0x65432710, 0x65432071, |
| 0x65432170, 0x65432107, 0x76543210, 0x70654321, 0x71654320, 0x71065432, |
| 0x72654310, 0x72065431, 0x72165430, 0x72106543, 0x73654210, 0x73065421, |
| 0x73165420, 0x73106542, 0x73265410, 0x73206541, 0x73216540, 0x73210654, |
| 0x74653210, 0x74065321, 0x74165320, 0x74106532, 0x74265310, 0x74206531, |
| 0x74216530, 0x74210653, 0x74365210, 0x74306521, 0x74316520, 0x74310652, |
| 0x74326510, 0x74320651, 0x74321650, 0x74321065, 0x75643210, 0x75064321, |
| 0x75164320, 0x75106432, 0x75264310, 0x75206431, 0x75216430, 0x75210643, |
| 0x75364210, 0x75306421, 0x75316420, 0x75310642, 0x75326410, 0x75320641, |
| 0x75321640, 0x75321064, 0x75463210, 0x75406321, 0x75416320, 0x75410632, |
| 0x75426310, 0x75420631, 0x75421630, 0x75421063, 0x75436210, 0x75430621, |
| 0x75431620, 0x75431062, 0x75432610, 0x75432061, 0x75432160, 0x75432106, |
| 0x76543210, 0x76054321, 0x76154320, 0x76105432, 0x76254310, 0x76205431, |
| 0x76215430, 0x76210543, 0x76354210, 0x76305421, 0x76315420, 0x76310542, |
| 0x76325410, 0x76320541, 0x76321540, 0x76321054, 0x76453210, 0x76405321, |
| 0x76415320, 0x76410532, 0x76425310, 0x76420531, 0x76421530, 0x76421053, |
| 0x76435210, 0x76430521, 0x76431520, 0x76431052, 0x76432510, 0x76432051, |
| 0x76432150, 0x76432105, 0x76543210, 0x76504321, 0x76514320, 0x76510432, |
| 0x76524310, 0x76520431, 0x76521430, 0x76521043, 0x76534210, 0x76530421, |
| 0x76531420, 0x76531042, 0x76532410, 0x76532041, 0x76532140, 0x76532104, |
| 0x76543210, 0x76540321, 0x76541320, 0x76541032, 0x76542310, 0x76542031, |
| 0x76542130, 0x76542103, 0x76543210, 0x76543021, 0x76543120, 0x76543102, |
| 0x76543210, 0x76543201, 0x76543210, 0x76543210}; |
| |
// For lane i, shift the i-th 4-bit index down to bits [0, 4) -
| // _mm512_permutexvar_epi64 will ignore the upper bits. |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du64; |
| const auto packed = Set(du64, packed_array[mask.raw]); |
| alignas(64) static constexpr uint64_t shifts[8] = {0, 4, 8, 12, |
| 16, 20, 24, 28}; |
| const auto indices = Indices512<T>{(packed >> Load(du64, shifts)).raw}; |
| return TableLookupLanes(v, indices); |
| } |
| |
| // ------------------------------ LoadInterleaved4 |
| |
| // Actually implemented in generic_ops, we just overload LoadTransposedBlocks4. |
| namespace detail { |
| |
| // Type-safe wrapper. |
| template <_MM_PERM_ENUM kPerm, typename T> |
| Vec512<T> Shuffle128(const Vec512<T> lo, const Vec512<T> hi) { |
| const DFromV<decltype(lo)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, VFromD<decltype(du)>{_mm512_shuffle_i64x2( |
| BitCast(du, lo).raw, BitCast(du, hi).raw, kPerm)}); |
| } |
| template <_MM_PERM_ENUM kPerm> |
| Vec512<float> Shuffle128(const Vec512<float> lo, const Vec512<float> hi) { |
| return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, kPerm)}; |
| } |
| template <_MM_PERM_ENUM kPerm> |
| Vec512<double> Shuffle128(const Vec512<double> lo, const Vec512<double> hi) { |
| return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, kPerm)}; |
| } |
| |
| // Input (128-bit blocks): |
| // 3 2 1 0 (<- first block in unaligned) |
| // 7 6 5 4 |
| // b a 9 8 |
| // Output: |
| // 9 6 3 0 (LSB of A) |
| // a 7 4 1 |
| // b 8 5 2 |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API void LoadTransposedBlocks3(D d, const TFromD<D>* HWY_RESTRICT unaligned, |
| VFromD<D>& A, VFromD<D>& B, VFromD<D>& C) { |
| HWY_LANES_CONSTEXPR size_t N = Lanes(d); |
| const VFromD<D> v3210 = LoadU(d, unaligned + 0 * N); |
| const VFromD<D> v7654 = LoadU(d, unaligned + 1 * N); |
| const VFromD<D> vba98 = LoadU(d, unaligned + 2 * N); |
| |
| const VFromD<D> v5421 = detail::Shuffle128<_MM_PERM_BACB>(v3210, v7654); |
| const VFromD<D> va976 = detail::Shuffle128<_MM_PERM_CBDC>(v7654, vba98); |
| |
| A = detail::Shuffle128<_MM_PERM_CADA>(v3210, va976); |
| B = detail::Shuffle128<_MM_PERM_DBCA>(v5421, va976); |
| C = detail::Shuffle128<_MM_PERM_DADB>(v5421, vba98); |
| } |
| |
| // Input (128-bit blocks): |
| // 3 2 1 0 (<- first block in unaligned) |
| // 7 6 5 4 |
| // b a 9 8 |
| // f e d c |
| // Output: |
| // c 8 4 0 (LSB of A) |
| // d 9 5 1 |
| // e a 6 2 |
| // f b 7 3 |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API void LoadTransposedBlocks4(D d, const TFromD<D>* HWY_RESTRICT unaligned, |
| VFromD<D>& vA, VFromD<D>& vB, VFromD<D>& vC, |
| VFromD<D>& vD) { |
| HWY_LANES_CONSTEXPR size_t N = Lanes(d); |
| const VFromD<D> v3210 = LoadU(d, unaligned + 0 * N); |
| const VFromD<D> v7654 = LoadU(d, unaligned + 1 * N); |
| const VFromD<D> vba98 = LoadU(d, unaligned + 2 * N); |
| const VFromD<D> vfedc = LoadU(d, unaligned + 3 * N); |
| |
| const VFromD<D> v5410 = detail::Shuffle128<_MM_PERM_BABA>(v3210, v7654); |
| const VFromD<D> vdc98 = detail::Shuffle128<_MM_PERM_BABA>(vba98, vfedc); |
| const VFromD<D> v7632 = detail::Shuffle128<_MM_PERM_DCDC>(v3210, v7654); |
| const VFromD<D> vfeba = detail::Shuffle128<_MM_PERM_DCDC>(vba98, vfedc); |
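  // E.g. vA gathers blocks (c 8 4 0): _MM_PERM_CACA selects lo[0]=0, lo[2]=4,
  // hi[0]=8, hi[2]=c from v5410 = (5 4 1 0) and vdc98 = (d c 9 8).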
| vA = detail::Shuffle128<_MM_PERM_CACA>(v5410, vdc98); |
| vB = detail::Shuffle128<_MM_PERM_DBDB>(v5410, vdc98); |
| vC = detail::Shuffle128<_MM_PERM_CACA>(v7632, vfeba); |
| vD = detail::Shuffle128<_MM_PERM_DBDB>(v7632, vfeba); |
| } |
| |
| } // namespace detail |
| |
| // ------------------------------ StoreInterleaved2 |
| |
// Implemented in generic_ops; we just overload StoreTransposedBlocks2/3/4.
| |
| namespace detail { |
| |
| // Input (128-bit blocks): |
| // 6 4 2 0 (LSB of i) |
| // 7 5 3 1 |
| // Output: |
| // 3 2 1 0 |
| // 7 6 5 4 |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API void StoreTransposedBlocks2(const VFromD<D> i, const VFromD<D> j, D d, |
| TFromD<D>* HWY_RESTRICT unaligned) { |
| HWY_LANES_CONSTEXPR size_t N = Lanes(d); |
| const auto j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); |
| const auto j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); |
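  // Passing the same vector as both operands permutes blocks within it:
  // _MM_PERM_DBCA turns (j1 j0 i1 i0) into (j1 i1 j0 i0).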
| const auto j1_i1_j0_i0 = |
| detail::Shuffle128<_MM_PERM_DBCA>(j1_j0_i1_i0, j1_j0_i1_i0); |
| const auto j3_i3_j2_i2 = |
| detail::Shuffle128<_MM_PERM_DBCA>(j3_j2_i3_i2, j3_j2_i3_i2); |
| StoreU(j1_i1_j0_i0, d, unaligned + 0 * N); |
| StoreU(j3_i3_j2_i2, d, unaligned + 1 * N); |
| } |
| |
| // Input (128-bit blocks): |
| // 9 6 3 0 (LSB of i) |
| // a 7 4 1 |
| // b 8 5 2 |
| // Output: |
| // 3 2 1 0 |
| // 7 6 5 4 |
| // b a 9 8 |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API void StoreTransposedBlocks3(const VFromD<D> i, const VFromD<D> j, |
| const VFromD<D> k, D d, |
| TFromD<D>* HWY_RESTRICT unaligned) { |
| HWY_LANES_CONSTEXPR size_t N = Lanes(d); |
| const VFromD<D> j2_j0_i2_i0 = detail::Shuffle128<_MM_PERM_CACA>(i, j); |
| const VFromD<D> i3_i1_k2_k0 = detail::Shuffle128<_MM_PERM_DBCA>(k, i); |
| const VFromD<D> j3_j1_k3_k1 = detail::Shuffle128<_MM_PERM_DBDB>(k, j); |
| |
| const VFromD<D> out0 = // i1 k0 j0 i0 |
| detail::Shuffle128<_MM_PERM_CACA>(j2_j0_i2_i0, i3_i1_k2_k0); |
| const VFromD<D> out1 = // j2 i2 k1 j1 |
| detail::Shuffle128<_MM_PERM_DBAC>(j3_j1_k3_k1, j2_j0_i2_i0); |
| const VFromD<D> out2 = // k3 j3 i3 k2 |
| detail::Shuffle128<_MM_PERM_BDDB>(i3_i1_k2_k0, j3_j1_k3_k1); |
| |
| StoreU(out0, d, unaligned + 0 * N); |
| StoreU(out1, d, unaligned + 1 * N); |
| StoreU(out2, d, unaligned + 2 * N); |
| } |
| |
| // Input (128-bit blocks): |
| // c 8 4 0 (LSB of i) |
| // d 9 5 1 |
| // e a 6 2 |
| // f b 7 3 |
| // Output: |
| // 3 2 1 0 |
| // 7 6 5 4 |
| // b a 9 8 |
| // f e d c |
| template <class D, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_API void StoreTransposedBlocks4(const VFromD<D> i, const VFromD<D> j, |
| const VFromD<D> k, const VFromD<D> l, D d, |
| TFromD<D>* HWY_RESTRICT unaligned) { |
| HWY_LANES_CONSTEXPR size_t N = Lanes(d); |
| const VFromD<D> j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); |
| const VFromD<D> l1_l0_k1_k0 = detail::Shuffle128<_MM_PERM_BABA>(k, l); |
| const VFromD<D> j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); |
| const VFromD<D> l3_l2_k3_k2 = detail::Shuffle128<_MM_PERM_DCDC>(k, l); |
| const VFromD<D> out0 = |
| detail::Shuffle128<_MM_PERM_CACA>(j1_j0_i1_i0, l1_l0_k1_k0); |
| const VFromD<D> out1 = |
| detail::Shuffle128<_MM_PERM_DBDB>(j1_j0_i1_i0, l1_l0_k1_k0); |
| const VFromD<D> out2 = |
| detail::Shuffle128<_MM_PERM_CACA>(j3_j2_i3_i2, l3_l2_k3_k2); |
| const VFromD<D> out3 = |
| detail::Shuffle128<_MM_PERM_DBDB>(j3_j2_i3_i2, l3_l2_k3_k2); |
| StoreU(out0, d, unaligned + 0 * N); |
| StoreU(out1, d, unaligned + 1 * N); |
| StoreU(out2, d, unaligned + 2 * N); |
| StoreU(out3, d, unaligned + 3 * N); |
| } |
| |
| } // namespace detail |
| |
| // ------------------------------ Additional mask logical operations |
| |
| template <class T> |
| HWY_API Mask512<T> SetAtOrAfterFirst(Mask512<T> mask) { |
| return Mask512<T>{ |
| static_cast<typename Mask512<T>::Raw>(0u - detail::AVX3Blsi(mask.raw))}; |
| } |
| template <class T> |
| HWY_API Mask512<T> SetBeforeFirst(Mask512<T> mask) { |
| return Mask512<T>{ |
| static_cast<typename Mask512<T>::Raw>(detail::AVX3Blsi(mask.raw) - 1u)}; |
| } |
| template <class T> |
| HWY_API Mask512<T> SetAtOrBeforeFirst(Mask512<T> mask) { |
| return Mask512<T>{ |
| static_cast<typename Mask512<T>::Raw>(detail::AVX3Blsmsk(mask.raw))}; |
| } |
| template <class T> |
| HWY_API Mask512<T> SetOnlyFirst(Mask512<T> mask) { |
| return Mask512<T>{ |
| static_cast<typename Mask512<T>::Raw>(detail::AVX3Blsi(mask.raw))}; |
| } |
| |
| // ------------------------------ Shl (Dup128VecFromValues) |
| |
| HWY_API Vec512<uint16_t> operator<<(Vec512<uint16_t> v, Vec512<uint16_t> bits) { |
| return Vec512<uint16_t>{_mm512_sllv_epi16(v.raw, bits.raw)}; |
| } |
| |
| // 8-bit: may use the << overload for uint16_t. |
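// On AVX3_DL, each byte is masked so that no set bit can be shifted past bit
// 7; the carry-less GF2P8 multiply by (1 << bits[i]) then never requires
// polynomial reduction and equals a plain left shift. Otherwise, even and odd
// bytes are shifted separately within 16-bit lanes and recombined via OddEven.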
| HWY_API Vec512<uint8_t> operator<<(Vec512<uint8_t> v, Vec512<uint8_t> bits) { |
| const DFromV<decltype(v)> d; |
| #if HWY_TARGET <= HWY_AVX3_DL |
| // kMask[i] = 0xFF >> i |
| const VFromD<decltype(d)> masks = |
| Dup128VecFromValues(d, 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0, |
| 0, 0, 0, 0, 0, 0, 0); |
| // kShl[i] = 1 << i |
| const VFromD<decltype(d)> shl = |
| Dup128VecFromValues(d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0, |
| 0, 0, 0, 0, 0, 0, 0); |
| v = And(v, TableLookupBytes(masks, bits)); |
| const VFromD<decltype(d)> mul = TableLookupBytes(shl, bits); |
| return VFromD<decltype(d)>{_mm512_gf2p8mul_epi8(v.raw, mul.raw)}; |
| #else |
| const Repartition<uint16_t, decltype(d)> dw; |
| using VW = VFromD<decltype(dw)>; |
| const VW even_mask = Set(dw, 0x00FF); |
| const VW odd_mask = Set(dw, 0xFF00); |
| const VW vw = BitCast(dw, v); |
| const VW bits16 = BitCast(dw, bits); |
| // Shift even lanes in-place |
| const VW evens = vw << And(bits16, even_mask); |
| const VW odds = And(vw, odd_mask) << ShiftRight<8>(bits16); |
| return OddEven(BitCast(d, odds), BitCast(d, evens)); |
| #endif |
| } |
| |
| HWY_API Vec512<uint32_t> operator<<(const Vec512<uint32_t> v, |
| const Vec512<uint32_t> bits) { |
| return Vec512<uint32_t>{_mm512_sllv_epi32(v.raw, bits.raw)}; |
| } |
| |
| HWY_API Vec512<uint64_t> operator<<(const Vec512<uint64_t> v, |
| const Vec512<uint64_t> bits) { |
| return Vec512<uint64_t>{_mm512_sllv_epi64(v.raw, bits.raw)}; |
| } |
| |
| // Signed left shift is the same as unsigned. |
| template <typename T, HWY_IF_SIGNED(T)> |
| HWY_API Vec512<T> operator<<(const Vec512<T> v, const Vec512<T> bits) { |
| const DFromV<decltype(v)> di; |
| const RebindToUnsigned<decltype(di)> du; |
| return BitCast(di, BitCast(du, v) << BitCast(du, bits)); |
| } |
| |
| // ------------------------------ Shr (IfVecThenElse) |
| |
| HWY_API Vec512<uint16_t> operator>>(const Vec512<uint16_t> v, |
| const Vec512<uint16_t> bits) { |
| return Vec512<uint16_t>{_mm512_srlv_epi16(v.raw, bits.raw)}; |
| } |
| |
| // 8-bit uses 16-bit shifts. |
| HWY_API Vec512<uint8_t> operator>>(Vec512<uint8_t> v, Vec512<uint8_t> bits) { |
| const DFromV<decltype(v)> d; |
| const RepartitionToWide<decltype(d)> dw; |
| using VW = VFromD<decltype(dw)>; |
| const VW mask = Set(dw, 0x00FF); |
| const VW vw = BitCast(dw, v); |
| const VW bits16 = BitCast(dw, bits); |
| const VW evens = And(vw, mask) >> And(bits16, mask); |
| // Shift odd lanes in-place |
| const VW odds = vw >> ShiftRight<8>(bits16); |
| return OddEven(BitCast(d, odds), BitCast(d, evens)); |
| } |
| |
| HWY_API Vec512<uint32_t> operator>>(const Vec512<uint32_t> v, |
| const Vec512<uint32_t> bits) { |
| return Vec512<uint32_t>{_mm512_srlv_epi32(v.raw, bits.raw)}; |
| } |
| |
| HWY_API Vec512<uint64_t> operator>>(const Vec512<uint64_t> v, |
| const Vec512<uint64_t> bits) { |
| return Vec512<uint64_t>{_mm512_srlv_epi64(v.raw, bits.raw)}; |
| } |
| |
| HWY_API Vec512<int16_t> operator>>(const Vec512<int16_t> v, |
| const Vec512<int16_t> bits) { |
| return Vec512<int16_t>{_mm512_srav_epi16(v.raw, bits.raw)}; |
| } |
| |
| // 8-bit uses 16-bit shifts. |
| HWY_API Vec512<int8_t> operator>>(Vec512<int8_t> v, Vec512<int8_t> bits) { |
| const DFromV<decltype(v)> d; |
| const RepartitionToWide<decltype(d)> dw; |
| const RebindToUnsigned<decltype(dw)> dw_u; |
| using VW = VFromD<decltype(dw)>; |
| const VW mask = Set(dw, 0x00FF); |
| const VW vw = BitCast(dw, v); |
| const VW bits16 = BitCast(dw, bits); |
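  // ShiftLeft<8> followed by arithmetic ShiftRight<8> sign-extends the even
  // (low) bytes to 16 bits before shifting by the even-lane counts.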
| const VW evens = ShiftRight<8>(ShiftLeft<8>(vw)) >> And(bits16, mask); |
| // Shift odd lanes in-place |
| const VW odds = vw >> BitCast(dw, ShiftRight<8>(BitCast(dw_u, bits16))); |
| return OddEven(BitCast(d, odds), BitCast(d, evens)); |
| } |
| |
| HWY_API Vec512<int32_t> operator>>(const Vec512<int32_t> v, |
| const Vec512<int32_t> bits) { |
| return Vec512<int32_t>{_mm512_srav_epi32(v.raw, bits.raw)}; |
| } |
| |
| HWY_API Vec512<int64_t> operator>>(const Vec512<int64_t> v, |
| const Vec512<int64_t> bits) { |
| return Vec512<int64_t>{_mm512_srav_epi64(v.raw, bits.raw)}; |
| } |
| |
| // ------------------------------ WidenMulPairwiseAdd |
| |
| #if HWY_NATIVE_DOT_BF16 |
| template <class DF, HWY_IF_F32_D(DF), HWY_IF_V_SIZE_D(DF, 64), |
| class VBF = VFromD<Repartition<bfloat16_t, DF>>> |
| HWY_API VFromD<DF> WidenMulPairwiseAdd(DF df, VBF a, VBF b) { |
| return VFromD<DF>{_mm512_dpbf16_ps(Zero(df).raw, |
| reinterpret_cast<__m512bh>(a.raw), |
| reinterpret_cast<__m512bh>(b.raw))}; |
| } |
| #endif // HWY_NATIVE_DOT_BF16 |
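
// Per i32 lane: a[2i] * b[2i] + a[2i+1] * b[2i+1] (vpmaddwd).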
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)> |
| HWY_API VFromD<D> WidenMulPairwiseAdd(D /*d32*/, Vec512<int16_t> a, |
| Vec512<int16_t> b) { |
| return VFromD<D>{_mm512_madd_epi16(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ SatWidenMulPairwiseAdd |
| template <class DI16, HWY_IF_V_SIZE_D(DI16, 64), HWY_IF_I16_D(DI16)> |
| HWY_API VFromD<DI16> SatWidenMulPairwiseAdd( |
| DI16 /* tag */, VFromD<Repartition<uint8_t, DI16>> a, |
| VFromD<Repartition<int8_t, DI16>> b) { |
| return VFromD<DI16>{_mm512_maddubs_epi16(a.raw, b.raw)}; |
| } |
| |
| // ------------------------------ SatWidenMulPairwiseAccumulate |
| #if HWY_TARGET <= HWY_AVX3_DL |
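// Per i32 lane: sum + a[2i] * b[2i] + a[2i+1] * b[2i+1] with signed
// saturation (vpdpwssds).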
| template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 64)> |
| HWY_API VFromD<DI32> SatWidenMulPairwiseAccumulate( |
| DI32 /* tag */, VFromD<Repartition<int16_t, DI32>> a, |
| VFromD<Repartition<int16_t, DI32>> b, VFromD<DI32> sum) { |
| return VFromD<DI32>{_mm512_dpwssds_epi32(sum.raw, a.raw, b.raw)}; |
| } |
| #endif // HWY_TARGET <= HWY_AVX3_DL |
| |
| // ------------------------------ ReorderWidenMulAccumulate |
| |
| #if HWY_NATIVE_DOT_BF16 |
| template <class DF, HWY_IF_F32_D(DF), HWY_IF_V_SIZE_D(DF, 64), |
| class VBF = VFromD<Repartition<bfloat16_t, DF>>> |
| HWY_API VFromD<DF> ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b, |
| const VFromD<DF> sum0, |
| VFromD<DF>& /*sum1*/) { |
| return VFromD<DF>{_mm512_dpbf16_ps(sum0.raw, |
| reinterpret_cast<__m512bh>(a.raw), |
| reinterpret_cast<__m512bh>(b.raw))}; |
| } |
| #endif // HWY_NATIVE_DOT_BF16 |
| |
| template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)> |
| HWY_API VFromD<D> ReorderWidenMulAccumulate(D d, Vec512<int16_t> a, |
| Vec512<int16_t> b, |
| const VFromD<D> sum0, |
| VFromD<D>& /*sum1*/) { |
| (void)d; |
| #if HWY_TARGET <= HWY_AVX3_DL |
| return VFromD<D>{_mm512_dpwssd_epi32(sum0.raw, a.raw, b.raw)}; |
| #else |
| return sum0 + WidenMulPairwiseAdd(d, a, b); |
| #endif |
| } |
| |
| HWY_API Vec512<int32_t> RearrangeToOddPlusEven(const Vec512<int32_t> sum0, |
| Vec512<int32_t> /*sum1*/) { |
| return sum0; // invariant already holds |
| } |
| |
| HWY_API Vec512<uint32_t> RearrangeToOddPlusEven(const Vec512<uint32_t> sum0, |
| Vec512<uint32_t> /*sum1*/) { |
| return sum0; // invariant already holds |
| } |
| |
| // ------------------------------ SumOfMulQuadAccumulate |
| |
| #if HWY_TARGET <= HWY_AVX3_DL |
| |
| template <class DI32, HWY_IF_V_SIZE_D(DI32, 64)> |
| HWY_API VFromD<DI32> SumOfMulQuadAccumulate( |
| DI32 /*di32*/, VFromD<Repartition<uint8_t, DI32>> a_u, |
| VFromD<Repartition<int8_t, DI32>> b_i, VFromD<DI32> sum) { |
| return VFromD<DI32>{_mm512_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; |
| } |
| |
#endif  // HWY_TARGET <= HWY_AVX3_DL
| |
| // ------------------------------ Reductions |
| |
| namespace detail { |
| |
| // Used by generic_ops-inl |
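// Combining with SwapAdjacentBlocks reduces block pairs (0,1) and (2,3);
// combining with ReverseBlocks then leaves every block holding the reduction
// over all four blocks.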
| template <class D, class Func, HWY_IF_V_SIZE_D(D, 64)> |
| HWY_INLINE VFromD<D> ReduceAcrossBlocks(D d, Func f, VFromD<D> v) { |
| v = f(v, SwapAdjacentBlocks(v)); |
| return f(v, ReverseBlocks(d, v)); |
| } |
| |
| } // namespace detail |
| |
| // ------------------------------ BitShuffle |
| #if HWY_TARGET <= HWY_AVX3_DL |
| template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>), |
| HWY_IF_V_SIZE_V(V, 64), HWY_IF_V_SIZE_V(VI, 64)> |
| HWY_API V BitShuffle(V v, VI idx) { |
| const DFromV<decltype(v)> d64; |
| const RebindToUnsigned<decltype(d64)> du64; |
| const Rebind<uint8_t, decltype(d64)> du8; |
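
  // Each of the 64 idx bytes selects bit (idx & 63) from its 64-bit lane of v;
  // the selected bits form a 64-bit mask that is widened back to one byte per
  // u64 lane below.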
| |
| const __mmask64 mmask64_bit_shuf_result = |
| _mm512_bitshuffle_epi64_mask(v.raw, idx.raw); |
| |
| #if HWY_ARCH_X86_64 |
| const VFromD<decltype(du8)> vu8_bit_shuf_result{ |
| _mm_cvtsi64_si128(static_cast<int64_t>(mmask64_bit_shuf_result))}; |
| #else |
| const int32_t i32_lo_bit_shuf_result = |
| static_cast<int32_t>(mmask64_bit_shuf_result); |
| const int32_t i32_hi_bit_shuf_result = |
| static_cast<int32_t>(_kshiftri_mask64(mmask64_bit_shuf_result, 32)); |
| |
| const VFromD<decltype(du8)> vu8_bit_shuf_result = ResizeBitCast( |
| du8, InterleaveLower( |
| Vec128<uint32_t>{_mm_cvtsi32_si128(i32_lo_bit_shuf_result)}, |
| Vec128<uint32_t>{_mm_cvtsi32_si128(i32_hi_bit_shuf_result)})); |
| #endif |
| |
| return BitCast(d64, PromoteTo(du64, vu8_bit_shuf_result)); |
| } |
| #endif // HWY_TARGET <= HWY_AVX3_DL |
| |
| // ------------------------------ MultiRotateRight |
| |
| #if HWY_TARGET <= HWY_AVX3_DL |
| |
| #ifdef HWY_NATIVE_MULTIROTATERIGHT |
| #undef HWY_NATIVE_MULTIROTATERIGHT |
| #else |
| #define HWY_NATIVE_MULTIROTATERIGHT |
| #endif |
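
// Each byte of idx selects a bit offset (mod 64) within the corresponding u64
// lane of v; the result byte holds the 8 bits starting at that offset,
// wrapping around (i.e. the low byte of the lane rotated right by the offset).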
| |
| template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>), |
| HWY_IF_V_SIZE_V(V, 64), HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> |
| HWY_API V MultiRotateRight(V v, VI idx) { |
| return V{_mm512_multishift_epi64_epi8(idx.raw, v.raw)}; |
| } |
| |
#endif  // HWY_TARGET <= HWY_AVX3_DL
| |
| // -------------------- LeadingZeroCount |
| |
| template <class V, HWY_IF_UI32(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)> |
| HWY_API V LeadingZeroCount(V v) { |
| return V{_mm512_lzcnt_epi32(v.raw)}; |
| } |
| |
| template <class V, HWY_IF_UI64(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)> |
| HWY_API V LeadingZeroCount(V v) { |
| return V{_mm512_lzcnt_epi64(v.raw)}; |
| } |
| |
| // NOLINTNEXTLINE(google-readability-namespace-comments) |
| } // namespace HWY_NAMESPACE |
| } // namespace hwy |
| HWY_AFTER_NAMESPACE(); |
| |
// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h
// includes - the warnings seem to be issued at the call sites of the
// intrinsics, i.e. in our code.
| HWY_DIAGNOSTICS(pop) |