| // Copyright 2021 Google LLC |
| // SPDX-License-Identifier: Apache-2.0 |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // Arm SVE[2] vectors (length not known at compile time). |
| // External include guard in highway.h - see comment there. |
| |
| #include <arm_sve.h> |
| |
| #include "third_party/highway/hwy/ops/shared-inl.h" |
| |
| // Arm C215 declares that SVE vector lengths will always be a power of two. |
| // We default to relying on this, which makes some operations more efficient. |
| // You can still opt into fixups by setting this to 0 (unsupported). |
| #ifndef HWY_SVE_IS_POW2 |
| #define HWY_SVE_IS_POW2 1 |
| #endif |
| |
| #if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 |
| #define HWY_SVE_HAVE_2 1 |
| #else |
| #define HWY_SVE_HAVE_2 0 |
| #endif |
| |
| // If 1, both __bf16 and a limited set of *_bf16 SVE intrinsics are available: |
| // create/get/set/dup, ld/st, sel, rev, trn, uzp, zip. |
| #if HWY_ARM_HAVE_SCALAR_BF16_TYPE && defined(__ARM_FEATURE_SVE_BF16) |
| #define HWY_SVE_HAVE_BF16_FEATURE 1 |
| #else |
| #define HWY_SVE_HAVE_BF16_FEATURE 0 |
| #endif |
| |
| // HWY_SVE_HAVE_BF16_VEC is defined to 1 if the SVE svbfloat16_t vector type |
| // is supported, even if HWY_SVE_HAVE_BF16_FEATURE (= intrinsics) is 0. |
| #if HWY_SVE_HAVE_BF16_FEATURE || \ |
| (HWY_COMPILER_CLANG >= 1200 && defined(__ARM_FEATURE_SVE_BF16)) || \ |
| HWY_COMPILER_GCC_ACTUAL >= 1000 |
| #define HWY_SVE_HAVE_BF16_VEC 1 |
| #else |
| #define HWY_SVE_HAVE_BF16_VEC 0 |
| #endif |
| |
| // HWY_SVE_HAVE_F32_TO_BF16C is defined to 1 if the SVE svcvt_bf16_f32_x |
| // and svcvtnt_bf16_f32_x intrinsics are available, even if the __bf16 type |
// is disabled.
| #if HWY_SVE_HAVE_BF16_VEC && defined(__ARM_FEATURE_SVE_BF16) |
| #define HWY_SVE_HAVE_F32_TO_BF16C 1 |
| #else |
| #define HWY_SVE_HAVE_F32_TO_BF16C 0 |
| #endif |
| |
| HWY_BEFORE_NAMESPACE(); |
| namespace hwy { |
| namespace HWY_NAMESPACE { |
| |
| template <class V> |
| struct DFromV_t {}; // specialized in macros |
| template <class V> |
| using DFromV = typename DFromV_t<RemoveConst<V>>::type; |
| |
| template <class V> |
| using TFromV = TFromD<DFromV<V>>; |
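
// Example (illustrative): for V = svfloat32_t, DFromV<V> is
// ScalableTag<float> and TFromV<V> is float.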
| |
| // ================================================== MACROS |
| |
| // Generate specializations and function definitions using X macros. Although |
| // harder to read and debug, writing everything manually is too bulky. |
| |
| namespace detail { // for code folding |
| |
| // Args: BASE, CHAR, BITS, HALF, NAME, OP |
| |
| // Unsigned: |
| #define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP) |
| #define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP) |
| #define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ |
| X_MACRO(uint, u, 32, 16, NAME, OP) |
| #define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ |
| X_MACRO(uint, u, 64, 32, NAME, OP) |
| |
| // Signed: |
| #define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP) |
| #define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP) |
| #define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, 16, NAME, OP) |
| #define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, 32, NAME, OP) |
| |
| // Float: |
| #define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ |
| X_MACRO(float, f, 16, 16, NAME, OP) |
| #define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ |
| X_MACRO(float, f, 32, 16, NAME, OP) |
| #define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \ |
| X_MACRO(float, f, 64, 32, NAME, OP) |
| |
| #define HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP) \ |
| X_MACRO(bfloat, bf, 16, 16, NAME, OP) |
| |
| #if HWY_SVE_HAVE_BF16_FEATURE |
| #define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP) |
| // We have both f16 and bf16, so nothing is emulated. |
| |
// NOTE: hwy::EnableIf<!hwy::IsSame<D, D>()>* = nullptr is used instead of
// hwy::EnableIf<false>* = nullptr because the condition, though always false,
// depends on the template argument D. This makes the overload fail via
// SFINAE instead of triggering a hard compiler error.
| #define HWY_SVE_IF_EMULATED_D(D) hwy::EnableIf<!hwy::IsSame<D, D>()>* = nullptr |
| #define HWY_GENERIC_IF_EMULATED_D(D) \ |
| hwy::EnableIf<!hwy::IsSame<D, D>()>* = nullptr |
| #define HWY_SVE_IF_NOT_EMULATED_D(D) hwy::EnableIf<true>* = nullptr |
| #else |
| #define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) |
| #define HWY_SVE_IF_EMULATED_D(D) HWY_IF_BF16_D(D) |
| #define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D) |
| #define HWY_SVE_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D) |
| #endif // HWY_SVE_HAVE_BF16_FEATURE |
| |
| // For all element sizes: |
| #define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) |
| |
| #define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) |
| |
| #define HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) |
| |
| // HWY_SVE_FOREACH_F does not include HWY_SVE_FOREACH_BF16 because SVE lacks |
| // bf16 overloads for some intrinsics (especially less-common arithmetic). |
| // However, this does include f16 because SVE supports it unconditionally. |
| #define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) |
| |
| // Commonly used type categories for a given element size: |
| #define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) |
| |
| #define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) |
| |
| #define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) |
| |
| #define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) |
| |
| #define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) |
| |
| // Commonly used type categories: |
| #define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) |
| |
| #define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) |
| |
| #define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ |
| HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) |
| |
| // Assemble types for use in x-macros |
| #define HWY_SVE_T(BASE, BITS) BASE##BITS##_t |
| #define HWY_SVE_D(BASE, BITS, N, POW2) Simd<HWY_SVE_T(BASE, BITS), N, POW2> |
| #define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t |
| #define HWY_SVE_TUPLE(BASE, BITS, MUL) sv##BASE##BITS##x##MUL##_t |
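
// Example expansions (illustrative):
//   HWY_SVE_T(uint, 8)         -> uint8_t
//   HWY_SVE_V(float, 32)       -> svfloat32_t
//   HWY_SVE_D(int, 16, N, P)   -> Simd<int16_t, N, P>
//   HWY_SVE_TUPLE(uint, 8, 2)  -> svuint8x2_t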
| |
| } // namespace detail |
| |
| #define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <> \ |
| struct DFromV_t<HWY_SVE_V(BASE, BITS)> { \ |
| using type = ScalableTag<HWY_SVE_T(BASE, BITS)>; \ |
| }; |
| |
| HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) |
| #endif |
| #undef HWY_SPECIALIZE |
| |
// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX
// instructions, and we only use it when the predicate is ptrue anyway.
| |
| // vector = f(vector), e.g. Not |
| #define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ |
| } |
| #define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS(v); \ |
| } |
| #define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) a) { \ |
| return sv##OP##_##CHAR##BITS##_m(no, m, a); \ |
| } |
| #define HWY_SVE_RETV_ARGMV(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS##_x(m, v); \ |
| } |
| #define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \ |
| return sv##OP##_##CHAR##BITS##_z(m, a); \ |
| } |
| |
| // vector = f(vector, scalar), e.g. detail::AddN |
| #define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ |
| } |
| #define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS(a, b); \ |
| } |
| |
| // vector = f(vector, vector), e.g. Add |
| #define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS(a, b); \ |
| } |
| // All-true mask |
| #define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ |
| } |
| // User-specified mask. Mask=false value is undefined and must be set by caller |
| // because SVE instructions take it from one of the two inputs, whereas |
| // AVX-512, RVV and Highway allow a third argument. |
| #define HWY_SVE_RETV_ARGMVV(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS##_x(m, a, b); \ |
| } |
| // User-specified mask. Mask=false value is zero. |
| #define HWY_SVE_RETV_ARGMVV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS##_z(m, a, b); \ |
| } |
| |
| #define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ |
| HWY_SVE_V(BASE, BITS) c) { \ |
| return sv##OP##_##CHAR##BITS(a, b, c); \ |
| } |
| #define HWY_SVE_RETV_ARGMVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ |
| HWY_SVE_V(BASE, BITS) c) { \ |
| return sv##OP##_##CHAR##BITS##_x(m, a, b, c); \ |
| } |
| #define HWY_SVE_RETV_ARGMVVV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(svbool_t m, HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \ |
| HWY_SVE_V(BASE, BITS) add) { \ |
| return sv##OP##_##CHAR##BITS##_z(m, x, mul, add); \ |
| } |
| |
| // ------------------------------ Lanes |
| |
| namespace detail { |
| |
// Returns the actual number of hardware lanes, without rounding down to a
// power of two.
| template <typename T, HWY_IF_T_SIZE(T, 1)> |
| HWY_INLINE size_t AllHardwareLanes() { |
| return svcntb_pat(SV_ALL); |
| } |
| template <typename T, HWY_IF_T_SIZE(T, 2)> |
| HWY_INLINE size_t AllHardwareLanes() { |
| return svcnth_pat(SV_ALL); |
| } |
| template <typename T, HWY_IF_T_SIZE(T, 4)> |
| HWY_INLINE size_t AllHardwareLanes() { |
| return svcntw_pat(SV_ALL); |
| } |
| template <typename T, HWY_IF_T_SIZE(T, 8)> |
| HWY_INLINE size_t AllHardwareLanes() { |
| return svcntd_pat(SV_ALL); |
| } |
| |
| // All-true mask from a macro |
| |
| #if HWY_SVE_IS_POW2 |
| #define HWY_SVE_ALL_PTRUE(BITS) svptrue_b##BITS() |
| #define HWY_SVE_PTRUE(BITS) svptrue_b##BITS() |
| #else |
| #define HWY_SVE_ALL_PTRUE(BITS) svptrue_pat_b##BITS(SV_ALL) |
| #define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2) |
| #endif // HWY_SVE_IS_POW2 |
| |
| } // namespace detail |
| |
| #if HWY_HAVE_SCALABLE |
| |
| // Returns actual number of lanes after capping by N and shifting. May return 0 |
| // (e.g. for "1/8th" of a u32x4 - would be 1 for 1/8th of u32x8). |
| template <typename T, size_t N, int kPow2> |
| HWY_API size_t Lanes(Simd<T, N, kPow2> d) { |
| const size_t actual = detail::AllHardwareLanes<T>(); |
| constexpr size_t kMaxLanes = MaxLanes(d); |
| constexpr int kClampedPow2 = HWY_MIN(kPow2, 0); |
| // Common case of full vectors: avoid any extra instructions. |
| if (detail::IsFull(d)) return actual; |
| return HWY_MIN(detail::ScaleByPower(actual, kClampedPow2), kMaxLanes); |
| } |
| |
| #endif // HWY_HAVE_SCALABLE |
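
// Example usage (illustrative): query the runtime vector length.
//   const ScalableTag<float> d;  // full vector, kPow2 = 0
//   const size_t N = Lanes(d);   // e.g. 4 when the SVE VL is 128 bits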
| |
| // ================================================== MASK INIT |
| |
| // One mask bit per byte; only the one belonging to the lowest byte is valid. |
| |
| // ------------------------------ FirstN |
| #define HWY_SVE_FIRSTN(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, size_t count) { \ |
| const size_t limit = detail::IsFull(d) ? count : HWY_MIN(Lanes(d), count); \ |
| return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast<uint32_t>(limit)); \ |
| } |
| HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_FIRSTN, FirstN, whilelt) |
| #endif |
| |
| template <class D, HWY_SVE_IF_EMULATED_D(D)> |
| svbool_t FirstN(D /* tag */, size_t count) { |
| return FirstN(RebindToUnsigned<D>(), count); |
| } |
| |
| #undef HWY_SVE_FIRSTN |
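
// Example (illustrative): FirstN(d, 3) activates (at most) the first three
// lanes; count >= Lanes(d) yields an all-true mask.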
| |
| template <class D> |
| using MFromD = svbool_t; |
| |
| namespace detail { |
| |
| #define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ |
| return HWY_SVE_PTRUE(BITS); \ |
| } \ |
| template <size_t N, int kPow2> \ |
| HWY_API svbool_t All##NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ |
| return HWY_SVE_ALL_PTRUE(BITS); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) // return all-true |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) |
| #undef HWY_SVE_WRAP_PTRUE |
| |
| HWY_API svbool_t PFalse() { return svpfalse_b(); } |
| |
| // Returns all-true if d is HWY_FULL or FirstN(N) after capping N. |
| // |
| // This is used in functions that load/store memory; other functions (e.g. |
| // arithmetic) can ignore d and use PTrue instead. |
| template <class D> |
| svbool_t MakeMask(D d) { |
| return IsFull(d) ? PTrue(d) : FirstN(d, Lanes(d)); |
| } |
| |
| } // namespace detail |
| |
| #ifdef HWY_NATIVE_MASK_FALSE |
| #undef HWY_NATIVE_MASK_FALSE |
| #else |
| #define HWY_NATIVE_MASK_FALSE |
| #endif |
| |
| template <class D> |
| HWY_API svbool_t MaskFalse(const D /*d*/) { |
| return detail::PFalse(); |
| } |
| |
| // ================================================== INIT |
| |
| // ------------------------------ Set |
| // vector = f(d, scalar), e.g. Set |
| #define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ |
| HWY_SVE_T(BASE, BITS) arg) { \ |
| return sv##OP##_##CHAR##BITS(arg); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n) |
| #if HWY_SVE_HAVE_BF16_FEATURE // for if-elif chain |
| HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, dup_n) |
| #elif HWY_SVE_HAVE_BF16_VEC |
| // Required for Zero and VFromD |
| template <class D, HWY_IF_BF16_D(D)> |
| HWY_API svbfloat16_t Set(D d, bfloat16_t arg) { |
| return svreinterpret_bf16_u16( |
| Set(RebindToUnsigned<decltype(d)>(), BitCastScalar<uint16_t>(arg))); |
| } |
| #else // neither bf16 feature nor vector: emulate with u16 |
| // Required for Zero and VFromD |
| template <class D, HWY_IF_BF16_D(D)> |
| HWY_API svuint16_t Set(D d, bfloat16_t arg) { |
| const RebindToUnsigned<decltype(d)> du; |
| return Set(du, BitCastScalar<uint16_t>(arg)); |
| } |
| #endif // HWY_SVE_HAVE_BF16_FEATURE |
| #undef HWY_SVE_SET |
| |
| template <class D> |
| using VFromD = decltype(Set(D(), TFromD<D>())); |
| |
| using VBF16 = VFromD<ScalableTag<bfloat16_t>>; |
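
// Example (illustrative): VFromD<ScalableTag<int32_t>> is svint32_t. When
// bf16 is emulated, VBF16 is svuint16_t rather than svbfloat16_t.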
| |
| // ------------------------------ MaskedSetOr/MaskedSet |
| |
| #define HWY_SVE_MASKED_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_T(BASE, BITS) op) { \ |
| return sv##OP##_##CHAR##BITS##_m(no, m, op); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_MASKED_SET_OR, MaskedSetOr, dup_n) |
| #undef HWY_SVE_MASKED_SET_OR |
| |
| #define HWY_SVE_MASKED_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ |
| svbool_t m, HWY_SVE_T(BASE, BITS) op) { \ |
| return sv##OP##_##CHAR##BITS##_z(m, op); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_MASKED_SET, MaskedSet, dup_n) |
| #undef HWY_SVE_MASKED_SET |
| |
| // ------------------------------ Zero |
| |
| template <class D> |
| VFromD<D> Zero(D d) { |
| // Cast to support bfloat16_t. |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, Set(du, 0)); |
| } |
| |
| // ------------------------------ BitCast |
| |
| namespace detail { |
| |
| // u8: no change |
| #define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ |
| return v; \ |
| } \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte( \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ |
| return v; \ |
| } |
| |
| // All other types |
| #define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_u8_##CHAR##BITS(v); \ |
| } \ |
| template <size_t N, int kPow2> \ |
| HWY_INLINE HWY_SVE_V(BASE, BITS) \ |
| BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \ |
| return sv##OP##_##CHAR##BITS##_u8(v); \ |
| } |
| |
| // U08 is special-cased, hence do not use FOREACH. |
| HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _) |
| HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret) |
| HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret) |
| HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret) |
| |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CAST, _, reinterpret) |
| #else // !(HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC) |
| template <class V, HWY_SVE_IF_EMULATED_D(DFromV<V>)> |
| HWY_INLINE svuint8_t BitCastToByte(V v) { |
| const RebindToUnsigned<DFromV<V>> du; |
| return BitCastToByte(BitCast(du, v)); |
| } |
| |
| template <class D, HWY_SVE_IF_EMULATED_D(D)> |
| HWY_INLINE VFromD<D> BitCastFromByte(D d, svuint8_t v) { |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCastFromByte(du, v); |
| } |
| #endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| |
| #undef HWY_SVE_CAST_NOP |
| #undef HWY_SVE_CAST |
| |
| } // namespace detail |
| |
| template <class D, class FromV> |
| HWY_API VFromD<D> BitCast(D d, FromV v) { |
| return detail::BitCastFromByte(d, detail::BitCastToByte(v)); |
| } |
| |
| // ------------------------------ Undefined |
| |
| #define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ |
| return sv##OP##_##CHAR##BITS(); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_UNDEFINED, Undefined, undef) |
| #endif |
| |
| template <class D, HWY_SVE_IF_EMULATED_D(D)> |
| VFromD<D> Undefined(D d) { |
| const RebindToUnsigned<D> du; |
| return BitCast(d, Undefined(du)); |
| } |
| |
| // ------------------------------ Tuple |
| |
| // tuples = f(d, v..), e.g. Create2 |
| #define HWY_SVE_CREATE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_TUPLE(BASE, BITS, 2) \ |
| NAME##2(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ |
| HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1) { \ |
| return sv##OP##2_##CHAR##BITS(v0, v1); \ |
| } \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_TUPLE(BASE, BITS, 3) NAME##3( \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v0, \ |
| HWY_SVE_V(BASE, BITS) v1, HWY_SVE_V(BASE, BITS) v2) { \ |
| return sv##OP##3_##CHAR##BITS(v0, v1, v2); \ |
| } \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_TUPLE(BASE, BITS, 4) \ |
| NAME##4(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ |
| HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ |
| HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3) { \ |
| return sv##OP##4_##CHAR##BITS(v0, v1, v2, v3); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_CREATE, Create, create) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CREATE, Create, create) |
| #endif |
| #undef HWY_SVE_CREATE |
| |
| template <class D> |
| using Vec2 = decltype(Create2(D(), Zero(D()), Zero(D()))); |
| template <class D> |
| using Vec3 = decltype(Create3(D(), Zero(D()), Zero(D()), Zero(D()))); |
| template <class D> |
| using Vec4 = decltype(Create4(D(), Zero(D()), Zero(D()), Zero(D()), Zero(D()))); |
| |
| #define HWY_SVE_GET(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t kIndex> \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple) { \ |
| return sv##OP##2_##CHAR##BITS(tuple, kIndex); \ |
| } \ |
| template <size_t kIndex> \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple) { \ |
| return sv##OP##3_##CHAR##BITS(tuple, kIndex); \ |
| } \ |
| template <size_t kIndex> \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple) { \ |
| return sv##OP##4_##CHAR##BITS(tuple, kIndex); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_GET, Get, get) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_GET, Get, get) |
| #endif |
| #undef HWY_SVE_GET |
| |
| #define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t kIndex> \ |
| HWY_API HWY_SVE_TUPLE(BASE, BITS, 2) \ |
| NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(BASE, BITS) vec) { \ |
| return sv##OP##2_##CHAR##BITS(tuple, kIndex, vec); \ |
| } \ |
| template <size_t kIndex> \ |
| HWY_API HWY_SVE_TUPLE(BASE, BITS, 3) \ |
| NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple, HWY_SVE_V(BASE, BITS) vec) { \ |
| return sv##OP##3_##CHAR##BITS(tuple, kIndex, vec); \ |
| } \ |
| template <size_t kIndex> \ |
| HWY_API HWY_SVE_TUPLE(BASE, BITS, 4) \ |
| NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple, HWY_SVE_V(BASE, BITS) vec) { \ |
| return sv##OP##4_##CHAR##BITS(tuple, kIndex, vec); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_SET, Set, set) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_SET, Set, set) |
| #endif |
| #undef HWY_SVE_SET |
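
// Example (illustrative): round-trip vectors through a two-element tuple.
//   const ScalableTag<uint8_t> d;
//   auto t = Create2(d, v0, v1);    // v0, v1: svuint8_t
//   t = Set2<1>(t, v2);             // replace the second member
//   const svuint8_t first = Get2<0>(t);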
| |
| // ------------------------------ ResizeBitCast |
| |
| // Same as BitCast on SVE |
| template <class D, class FromV> |
| HWY_API VFromD<D> ResizeBitCast(D d, FromV v) { |
| return BitCast(d, v); |
| } |
| |
| // ------------------------------ Dup128VecFromValues |
| |
| template <class D, HWY_IF_I8_D(D)> |
| HWY_API svint8_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3, TFromD<D> t4, |
| TFromD<D> t5, TFromD<D> t6, TFromD<D> t7, |
| TFromD<D> t8, TFromD<D> t9, TFromD<D> t10, |
| TFromD<D> t11, TFromD<D> t12, |
| TFromD<D> t13, TFromD<D> t14, |
| TFromD<D> t15) { |
| return svdupq_n_s8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, |
| t14, t15); |
| } |
| |
| template <class D, HWY_IF_U8_D(D)> |
| HWY_API svuint8_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3, TFromD<D> t4, |
| TFromD<D> t5, TFromD<D> t6, TFromD<D> t7, |
| TFromD<D> t8, TFromD<D> t9, TFromD<D> t10, |
| TFromD<D> t11, TFromD<D> t12, |
| TFromD<D> t13, TFromD<D> t14, |
| TFromD<D> t15) { |
| return svdupq_n_u8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, |
| t14, t15); |
| } |
| |
| template <class D, HWY_IF_I16_D(D)> |
| HWY_API svint16_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3, TFromD<D> t4, |
| TFromD<D> t5, TFromD<D> t6, |
| TFromD<D> t7) { |
| return svdupq_n_s16(t0, t1, t2, t3, t4, t5, t6, t7); |
| } |
| |
| template <class D, HWY_IF_U16_D(D)> |
| HWY_API svuint16_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3, TFromD<D> t4, |
| TFromD<D> t5, TFromD<D> t6, |
| TFromD<D> t7) { |
| return svdupq_n_u16(t0, t1, t2, t3, t4, t5, t6, t7); |
| } |
| |
| template <class D, HWY_IF_F16_D(D)> |
| HWY_API svfloat16_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3, |
| TFromD<D> t4, TFromD<D> t5, |
| TFromD<D> t6, TFromD<D> t7) { |
| return svdupq_n_f16(t0, t1, t2, t3, t4, t5, t6, t7); |
| } |
| |
| template <class D, HWY_IF_BF16_D(D)> |
| HWY_API VBF16 Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1, TFromD<D> t2, |
| TFromD<D> t3, TFromD<D> t4, TFromD<D> t5, |
| TFromD<D> t6, TFromD<D> t7) { |
| #if HWY_SVE_HAVE_BF16_FEATURE |
| (void)d; |
| return svdupq_n_bf16(t0, t1, t2, t3, t4, t5, t6, t7); |
| #else |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast( |
| d, Dup128VecFromValues( |
| du, BitCastScalar<uint16_t>(t0), BitCastScalar<uint16_t>(t1), |
| BitCastScalar<uint16_t>(t2), BitCastScalar<uint16_t>(t3), |
| BitCastScalar<uint16_t>(t4), BitCastScalar<uint16_t>(t5), |
| BitCastScalar<uint16_t>(t6), BitCastScalar<uint16_t>(t7))); |
| #endif |
| } |
| |
| template <class D, HWY_IF_I32_D(D)> |
| HWY_API svint32_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3) { |
| return svdupq_n_s32(t0, t1, t2, t3); |
| } |
| |
| template <class D, HWY_IF_U32_D(D)> |
| HWY_API svuint32_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3) { |
| return svdupq_n_u32(t0, t1, t2, t3); |
| } |
| |
| template <class D, HWY_IF_F32_D(D)> |
| HWY_API svfloat32_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1, |
| TFromD<D> t2, TFromD<D> t3) { |
| return svdupq_n_f32(t0, t1, t2, t3); |
| } |
| |
| template <class D, HWY_IF_I64_D(D)> |
| HWY_API svint64_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) { |
| return svdupq_n_s64(t0, t1); |
| } |
| |
| template <class D, HWY_IF_U64_D(D)> |
| HWY_API svuint64_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) { |
| return svdupq_n_u64(t0, t1); |
| } |
| |
| template <class D, HWY_IF_F64_D(D)> |
| HWY_API svfloat64_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) { |
| return svdupq_n_f64(t0, t1); |
| } |
| |
| // ------------------------------ GetLane |
| |
| namespace detail { |
| #define HWY_SVE_GET_LANE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_INLINE HWY_SVE_T(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ |
| return sv##OP##_##CHAR##BITS(mask, v); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLaneM, lasta) |
| HWY_SVE_FOREACH(HWY_SVE_GET_LANE, ExtractLastMatchingLaneM, lastb) |
| #undef HWY_SVE_GET_LANE |
| } // namespace detail |
| |
| template <class V> |
| HWY_API TFromV<V> GetLane(V v) { |
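  // svlasta returns the element after the last active lane; with an all-false
  // mask it wraps around to lane 0, i.e. the first lane.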
| return detail::GetLaneM(v, detail::PFalse()); |
| } |
| |
| // ================================================== LOGICAL |
| |
| // detail::*N() functions accept a scalar argument to avoid extra Set(). |
| |
| // ------------------------------ Not |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not ) // NOLINT |
| |
| // ------------------------------ And |
| |
| namespace detail { |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n) |
| } // namespace detail |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and) |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V And(const V a, const V b) { |
| const DFromV<V> df; |
| const RebindToUnsigned<decltype(df)> du; |
| return BitCast(df, And(BitCast(du, a), BitCast(du, b))); |
| } |
| |
| // ------------------------------ Or |
| |
| namespace detail { |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, OrN, orr_n) |
| } // namespace detail |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr) |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V Or(const V a, const V b) { |
| const DFromV<V> df; |
| const RebindToUnsigned<decltype(df)> du; |
| return BitCast(df, Or(BitCast(du, a), BitCast(du, b))); |
| } |
| |
| // ------------------------------ MaskedOr |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedOr, orr) |
| |
| // ------------------------------ Xor |
| |
| namespace detail { |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n) |
| } // namespace detail |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor) |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V Xor(const V a, const V b) { |
| const DFromV<V> df; |
| const RebindToUnsigned<decltype(df)> du; |
| return BitCast(df, Xor(BitCast(du, a), BitCast(du, b))); |
| } |
| |
| // ------------------------------ AndNot |
| |
| namespace detail { |
| #define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n) |
| #undef HWY_SVE_RETV_ARGPVN_SWAP |
| } // namespace detail |
| |
| #define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ |
| } |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic) |
| #undef HWY_SVE_RETV_ARGPVV_SWAP |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V AndNot(const V a, const V b) { |
| const DFromV<V> df; |
| const RebindToUnsigned<decltype(df)> du; |
| return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b))); |
| } |
| |
| // ------------------------------ Xor3 |
| |
| #if HWY_SVE_HAVE_2 |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV, Xor3, eor3) |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V Xor3(const V x1, const V x2, const V x3) { |
| const DFromV<V> df; |
| const RebindToUnsigned<decltype(df)> du; |
| return BitCast(df, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3))); |
| } |
| |
| #else |
| template <class V> |
| HWY_API V Xor3(V x1, V x2, V x3) { |
| return Xor(x1, Xor(x2, x3)); |
| } |
| #endif |
| |
| // ------------------------------ Or3 |
| template <class V> |
| HWY_API V Or3(V o1, V o2, V o3) { |
| return Or(o1, Or(o2, o3)); |
| } |
| |
| // ------------------------------ OrAnd |
| template <class V> |
| HWY_API V OrAnd(const V o, const V a1, const V a2) { |
| return Or(o, And(a1, a2)); |
| } |
| |
| // ------------------------------ PopulationCount |
| |
| #ifdef HWY_NATIVE_POPCNT |
| #undef HWY_NATIVE_POPCNT |
| #else |
| #define HWY_NATIVE_POPCNT |
| #endif |
| |
| // Need to return original type instead of unsigned. |
| #define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| return BitCast(DFromV<decltype(v)>(), \ |
| sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \ |
| } |
| HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt) |
| #undef HWY_SVE_POPCNT |
| |
| // ================================================== SIGN |
| |
| // ------------------------------ Neg |
| HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg) |
| |
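// bf16 Neg is emulated (no native bf16 arithmetic here): flip the sign bit.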
| HWY_API VBF16 Neg(VBF16 v) { |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| using TU = TFromD<decltype(du)>; |
| return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask<TU>()))); |
| } |
| |
| // ------------------------------ SaturatedNeg |
| #if HWY_SVE_HAVE_2 |
| #ifdef HWY_NATIVE_SATURATED_NEG_8_16_32 |
| #undef HWY_NATIVE_SATURATED_NEG_8_16_32 |
| #else |
| #define HWY_NATIVE_SATURATED_NEG_8_16_32 |
| #endif |
| |
| #ifdef HWY_NATIVE_SATURATED_NEG_64 |
| #undef HWY_NATIVE_SATURATED_NEG_64 |
| #else |
| #define HWY_NATIVE_SATURATED_NEG_64 |
| #endif |
| |
| HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedNeg, qneg) |
| #endif // HWY_SVE_HAVE_2 |
| |
| // ================================================== ARITHMETIC |
| |
| // Per-target flags to prevent generic_ops-inl.h defining Add etc. |
| #ifdef HWY_NATIVE_OPERATOR_REPLACEMENTS |
| #undef HWY_NATIVE_OPERATOR_REPLACEMENTS |
| #else |
| #define HWY_NATIVE_OPERATOR_REPLACEMENTS |
| #endif |
| |
| // ------------------------------ Add |
| |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n) |
| } // namespace detail |
| |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add) |
| |
| // ------------------------------ Sub |
| |
| namespace detail { |
| // Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg. |
| #define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS##_z(pg, a, b); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n) |
| #undef HWY_SVE_RETV_ARGPVN_MASK |
| } // namespace detail |
| |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub) |
| |
| // ------------------------------ SumsOf8 |
| HWY_API svuint64_t SumsOf8(const svuint8_t v) { |
| const ScalableTag<uint32_t> du32; |
| const ScalableTag<uint64_t> du64; |
| const svbool_t pg = detail::PTrue(du64); |
| |
| const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1); |
  // Compute pairwise sum of u32 and extend to u64.
#if HWY_SVE_HAVE_2
| return svadalp_u64_x(pg, Zero(du64), sums_of_4); |
| #else |
| const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32); |
| // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended) |
| const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4)); |
| return Add(hi, lo); |
| #endif |
| } |
| |
| HWY_API svint64_t SumsOf8(const svint8_t v) { |
| const ScalableTag<int32_t> di32; |
| const ScalableTag<int64_t> di64; |
| const svbool_t pg = detail::PTrue(di64); |
| |
| const svint32_t sums_of_4 = svdot_n_s32(Zero(di32), v, 1); |
| #if HWY_SVE_HAVE_2 |
| return svadalp_s64_x(pg, Zero(di64), sums_of_4); |
| #else |
| const svint64_t hi = svasr_n_s64_x(pg, BitCast(di64, sums_of_4), 32); |
| // Isolate the lower 32 bits (to be added to the upper 32 and sign-extended) |
| const svint64_t lo = svextw_s64_x(pg, BitCast(di64, sums_of_4)); |
| return Add(hi, lo); |
| #endif |
| } |
| |
| // ------------------------------ SumsOf2 |
| #if HWY_SVE_HAVE_2 |
| namespace detail { |
| |
| HWY_INLINE svint16_t SumsOf2(hwy::SignedTag /*type_tag*/, |
| hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) { |
| const ScalableTag<int16_t> di16; |
| const svbool_t pg = detail::PTrue(di16); |
| return svadalp_s16_x(pg, Zero(di16), v); |
| } |
| |
| HWY_INLINE svuint16_t SumsOf2(hwy::UnsignedTag /*type_tag*/, |
| hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) { |
| const ScalableTag<uint16_t> du16; |
| const svbool_t pg = detail::PTrue(du16); |
| return svadalp_u16_x(pg, Zero(du16), v); |
| } |
| |
| HWY_INLINE svint32_t SumsOf2(hwy::SignedTag /*type_tag*/, |
| hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) { |
| const ScalableTag<int32_t> di32; |
| const svbool_t pg = detail::PTrue(di32); |
| return svadalp_s32_x(pg, Zero(di32), v); |
| } |
| |
| HWY_INLINE svuint32_t SumsOf2(hwy::UnsignedTag /*type_tag*/, |
| hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) { |
| const ScalableTag<uint32_t> du32; |
| const svbool_t pg = detail::PTrue(du32); |
| return svadalp_u32_x(pg, Zero(du32), v); |
| } |
| |
| HWY_INLINE svint64_t SumsOf2(hwy::SignedTag /*type_tag*/, |
| hwy::SizeTag<4> /*lane_size_tag*/, svint32_t v) { |
| const ScalableTag<int64_t> di64; |
| const svbool_t pg = detail::PTrue(di64); |
| return svadalp_s64_x(pg, Zero(di64), v); |
| } |
| |
| HWY_INLINE svuint64_t SumsOf2(hwy::UnsignedTag /*type_tag*/, |
| hwy::SizeTag<4> /*lane_size_tag*/, svuint32_t v) { |
| const ScalableTag<uint64_t> du64; |
| const svbool_t pg = detail::PTrue(du64); |
| return svadalp_u64_x(pg, Zero(du64), v); |
| } |
| |
| } // namespace detail |
| #endif // HWY_SVE_HAVE_2 |
| |
| // ------------------------------ SumsOf4 |
| namespace detail { |
| |
| HWY_INLINE svint32_t SumsOf4(hwy::SignedTag /*type_tag*/, |
| hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) { |
| return svdot_n_s32(Zero(ScalableTag<int32_t>()), v, 1); |
| } |
| |
| HWY_INLINE svuint32_t SumsOf4(hwy::UnsignedTag /*type_tag*/, |
| hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) { |
| return svdot_n_u32(Zero(ScalableTag<uint32_t>()), v, 1); |
| } |
| |
| HWY_INLINE svint64_t SumsOf4(hwy::SignedTag /*type_tag*/, |
| hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) { |
| return svdot_n_s64(Zero(ScalableTag<int64_t>()), v, 1); |
| } |
| |
| HWY_INLINE svuint64_t SumsOf4(hwy::UnsignedTag /*type_tag*/, |
| hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) { |
| return svdot_n_u64(Zero(ScalableTag<uint64_t>()), v, 1); |
| } |
| |
| } // namespace detail |
| |
| // ------------------------------ SaturatedAdd |
| |
| #ifdef HWY_NATIVE_I32_SATURATED_ADDSUB |
| #undef HWY_NATIVE_I32_SATURATED_ADDSUB |
| #else |
| #define HWY_NATIVE_I32_SATURATED_ADDSUB |
| #endif |
| |
| #ifdef HWY_NATIVE_U32_SATURATED_ADDSUB |
| #undef HWY_NATIVE_U32_SATURATED_ADDSUB |
| #else |
| #define HWY_NATIVE_U32_SATURATED_ADDSUB |
| #endif |
| |
| #ifdef HWY_NATIVE_I64_SATURATED_ADDSUB |
| #undef HWY_NATIVE_I64_SATURATED_ADDSUB |
| #else |
| #define HWY_NATIVE_I64_SATURATED_ADDSUB |
| #endif |
| |
| #ifdef HWY_NATIVE_U64_SATURATED_ADDSUB |
| #undef HWY_NATIVE_U64_SATURATED_ADDSUB |
| #else |
| #define HWY_NATIVE_U64_SATURATED_ADDSUB |
| #endif |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd) |
| |
| // ------------------------------ SaturatedSub |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub) |
| |
| // ------------------------------ AbsDiff |
| #ifdef HWY_NATIVE_INTEGER_ABS_DIFF |
| #undef HWY_NATIVE_INTEGER_ABS_DIFF |
| #else |
| #define HWY_NATIVE_INTEGER_ABS_DIFF |
| #endif |
| |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, AbsDiff, abd) |
| |
| // ------------------------------ ShiftLeft[Same] |
| |
| #define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <int kBits> \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits); \ |
| } \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME##Same(HWY_SVE_V(BASE, BITS) v, int bits) { \ |
| return sv##OP##_##CHAR##BITS##_x( \ |
| HWY_SVE_PTRUE(BITS), v, static_cast<HWY_SVE_T(uint, BITS)>(bits)); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n) |
| |
| // ------------------------------ ShiftRight[Same] |
| |
| HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n) |
| HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n) |
| |
| #undef HWY_SVE_SHIFT_N |
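
// Example (illustrative): ShiftLeft<2>(v) uses a compile-time shift count;
// ShiftLeftSame(v, n) applies the same runtime count n to all lanes.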
| |
| // ------------------------------ MaskedShift[Left/Right] |
| |
| #define HWY_SVE_SHIFT_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <int kBits> \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ |
| auto shifts = static_cast<HWY_SVE_T(uint, BITS)>(kBits); \ |
| return sv##OP##_##CHAR##BITS##_z(m, v, shifts); \ |
| } |
| HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_Z, MaskedShiftLeft, lsl_n) |
| HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_Z, MaskedShiftRight, asr_n) |
| HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_Z, MaskedShiftRight, lsr_n) |
| |
| #undef HWY_SVE_SHIFT_Z |
| |
| // ------------------------------ MaskedShiftRightOr |
| |
| #define HWY_SVE_SHIFT_OR(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <int kBits> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ |
| auto shifts = static_cast<HWY_SVE_T(uint, BITS)>(kBits); \ |
| return svsel##_##CHAR##BITS(m, sv##OP##_##CHAR##BITS##_z(m, v, shifts), \ |
| no); \ |
| } |
| HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, asr_n) |
| HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, lsr_n) |
| |
| #undef HWY_SVE_SHIFT_OR |
| |
| // ------------------------------ RotateRight |
| |
| #if HWY_SVE_HAVE_2 |
| |
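// SVE2 XAR rotates (v XOR zero) right; its immediate must be in [1, BITS],
// hence the kBits == 0 early-out and the HWY_MAX below.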
| #define HWY_SVE_ROTATE_RIGHT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <int kBits> \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| if (kBits == 0) return v; \ |
| return sv##OP##_##CHAR##BITS(v, Zero(DFromV<decltype(v)>()), \ |
| HWY_MAX(kBits, 1)); \ |
| } |
| |
| HWY_SVE_FOREACH_U(HWY_SVE_ROTATE_RIGHT_N, RotateRight, xar_n) |
| HWY_SVE_FOREACH_I(HWY_SVE_ROTATE_RIGHT_N, RotateRight, xar_n) |
| |
| #undef HWY_SVE_ROTATE_RIGHT_N |
| |
| #else // !HWY_SVE_HAVE_2 |
| template <int kBits, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> |
| HWY_API V RotateRight(const V v) { |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| |
| constexpr size_t kSizeInBits = sizeof(TFromV<V>) * 8; |
| static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); |
| if (kBits == 0) return v; |
| |
| return Or(BitCast(d, ShiftRight<kBits>(BitCast(du, v))), |
| ShiftLeft<HWY_MIN(kSizeInBits - 1, kSizeInBits - kBits)>(v)); |
| } |
| #endif |
| |
| // ------------------------------ Shl, Shr |
| |
| #define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ |
| const RebindToUnsigned<DFromV<decltype(v)>> du; \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ |
| BitCast(du, bits)); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl) |
| |
| HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr) |
| HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr) |
| |
| #undef HWY_SVE_SHIFT |
| |
| // ------------------------------ RoundingShiftLeft[Same]/RoundingShr |
| |
| #if HWY_SVE_HAVE_2 |
| |
| #ifdef HWY_NATIVE_ROUNDING_SHR |
| #undef HWY_NATIVE_ROUNDING_SHR |
| #else |
| #define HWY_NATIVE_ROUNDING_SHR |
| #endif |
| |
| #define HWY_SVE_ROUNDING_SHR_N(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <int kBits> \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| HWY_IF_CONSTEXPR(kBits == 0) { return v; } \ |
| \ |
| return sv##OP##_##CHAR##BITS##_x( \ |
| HWY_SVE_PTRUE(BITS), v, static_cast<uint64_t>(HWY_MAX(kBits, 1))); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_ROUNDING_SHR_N, RoundingShiftRight, rshr_n) |
| |
| #undef HWY_SVE_ROUNDING_SHR_N |
| |
| #define HWY_SVE_ROUNDING_SHR(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ |
| const RebindToSigned<DFromV<decltype(v)>> di; \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ |
| Neg(BitCast(di, bits))); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_ROUNDING_SHR, RoundingShr, rshl) |
| |
| #undef HWY_SVE_ROUNDING_SHR |
| |
| template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> |
| HWY_API V RoundingShiftRightSame(V v, int bits) { |
| const DFromV<V> d; |
| using T = TFromD<decltype(d)>; |
| return RoundingShr(v, Set(d, static_cast<T>(bits))); |
| } |
| |
| #endif // HWY_SVE_HAVE_2 |
| |
| // ------------------------------ BroadcastSignBit (ShiftRight) |
| template <class V> |
| HWY_API V BroadcastSignBit(const V v) { |
| return ShiftRight<sizeof(TFromV<V>) * 8 - 1>(v); |
| } |
| |
| // ------------------------------ Abs (ShiftRight, Add, Xor, AndN) |
| |
// Workaround for incorrect results with svabs.
| #if HWY_COMPILER_CLANG |
| template <class V, HWY_IF_SIGNED_V(V)> |
| HWY_API V Abs(V v) { |
| const V sign = BroadcastSignBit(v); |
| return Xor(Add(v, sign), sign); |
| } |
| |
| template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)> |
| HWY_NOINLINE V Abs(V v) { |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du; |
| using TU = MakeUnsigned<TFromD<decltype(d)>>; |
| return BitCast( |
| d, detail::AndN(BitCast(du, v), static_cast<TU>(~SignMask<TU>()))); |
| } |
| |
| #else |
| HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs) |
| #endif |
| |
| // ------------------------------ SaturatedAbs |
| #if HWY_SVE_HAVE_2 |
| #ifdef HWY_NATIVE_SATURATED_ABS |
| #undef HWY_NATIVE_SATURATED_ABS |
| #else |
| #define HWY_NATIVE_SATURATED_ABS |
| #endif |
| |
| HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs) |
| #endif // HWY_SVE_HAVE_2 |
| |
| // ------------------------------ MaskedAbsOr |
| HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_M, MaskedAbsOr, abs) |
| |
| // ------------------------------ MaskedAbs |
| HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_Z, MaskedAbs, abs) |
| |
| // ------------------------------ Mul |
| |
| // Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. |
| #ifdef HWY_NATIVE_MUL_8 |
| #undef HWY_NATIVE_MUL_8 |
| #else |
| #define HWY_NATIVE_MUL_8 |
| #endif |
| #ifdef HWY_NATIVE_MUL_64 |
| #undef HWY_NATIVE_MUL_64 |
| #else |
| #define HWY_NATIVE_MUL_64 |
| #endif |
| |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Mul, mul) |
| |
| // ------------------------------ MulHigh |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) |
| |
| // ------------------------------ MulFixedPoint15 |
| HWY_API svint16_t MulFixedPoint15(svint16_t a, svint16_t b) { |
| #if HWY_SVE_HAVE_2 |
| return svqrdmulh_s16(a, b); |
| #else |
| const DFromV<decltype(a)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| |
| const svuint16_t lo = BitCast(du, Mul(a, b)); |
| const svint16_t hi = MulHigh(a, b); |
| // We want (lo + 0x4000) >> 15, but that can overflow, and if it does we must |
| // carry that into the result. Instead isolate the top two bits because only |
| // they can influence the result. |
| const svuint16_t lo_top2 = ShiftRight<14>(lo); |
| // Bits 11: add 2, 10: add 1, 01: add 1, 00: add 0. |
| const svuint16_t rounding = ShiftRight<1>(detail::AddN(lo_top2, 1)); |
| return Add(Add(hi, hi), BitCast(d, rounding)); |
| #endif |
| } |
| |
| // ------------------------------ Div |
| #ifdef HWY_NATIVE_INT_DIV |
| #undef HWY_NATIVE_INT_DIV |
| #else |
| #define HWY_NATIVE_INT_DIV |
| #endif |
| |
| HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, Div, div) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPVV, Div, div) |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div) |
| |
| // ------------------------------ ApproximateReciprocal |
| #ifdef HWY_NATIVE_F64_APPROX_RECIP |
| #undef HWY_NATIVE_F64_APPROX_RECIP |
| #else |
| #define HWY_NATIVE_F64_APPROX_RECIP |
| #endif |
| |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe) |
| |
| // ------------------------------ Sqrt |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt) |
| |
| // ------------------------------ MaskedSqrt |
| #ifdef HWY_NATIVE_MASKED_SQRT |
| #undef HWY_NATIVE_MASKED_SQRT |
| #else |
| #define HWY_NATIVE_MASKED_SQRT |
| #endif |
| |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_Z, MaskedSqrt, sqrt) |
| |
| // ------------------------------ ApproximateReciprocalSqrt |
| #ifdef HWY_NATIVE_F64_APPROX_RSQRT |
| #undef HWY_NATIVE_F64_APPROX_RSQRT |
| #else |
| #define HWY_NATIVE_F64_APPROX_RSQRT |
| #endif |
| |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte) |
| |
| // ------------------------------ MulAdd |
| |
| // Per-target flag to prevent generic_ops-inl.h from defining int MulAdd. |
| #ifdef HWY_NATIVE_INT_FMA |
| #undef HWY_NATIVE_INT_FMA |
| #else |
| #define HWY_NATIVE_INT_FMA |
| #endif |
| |
| #define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \ |
| HWY_SVE_V(BASE, BITS) add) { \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_FMA, MulAdd, mad) |
| |
| // ------------------------------ NegMulAdd |
| HWY_SVE_FOREACH(HWY_SVE_FMA, NegMulAdd, msb) |
| |
| // ------------------------------ MulSub |
| HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb) |
| |
| // ------------------------------ NegMulSub |
| HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad) |
| |
| #undef HWY_SVE_FMA |
| |
| // ------------------------------ Round etc. |
| |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn) |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm) |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp) |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz) |
| |
| // ================================================== MASK |
| |
| // ------------------------------ RebindMask |
| template <class D, typename MFrom> |
| HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) { |
| return mask; |
| } |
| |
| // ------------------------------ Mask logical |
| |
| HWY_API svbool_t Not(svbool_t m) { |
| // We don't know the lane type, so assume 8-bit. For larger types, this will |
| // de-canonicalize the predicate, i.e. set bits to 1 even though they do not |
| // correspond to the lowest byte in the lane. Arm says such bits are ignored. |
| return svnot_b_z(HWY_SVE_PTRUE(8), m); |
| } |
| HWY_API svbool_t And(svbool_t a, svbool_t b) { |
| return svand_b_z(b, b, a); // same order as AndNot for consistency |
| } |
| HWY_API svbool_t AndNot(svbool_t a, svbool_t b) { |
| return svbic_b_z(b, b, a); // reversed order like NEON |
| } |
| HWY_API svbool_t Or(svbool_t a, svbool_t b) { |
| return svsel_b(a, a, b); // a ? true : b |
| } |
| HWY_API svbool_t Xor(svbool_t a, svbool_t b) { |
| return svsel_b(a, svnand_b_z(a, a, b), b); // a ? !(a & b) : b. |
| } |
| |
| HWY_API svbool_t ExclusiveNeither(svbool_t a, svbool_t b) { |
| return svnor_b_z(HWY_SVE_PTRUE(8), a, b); // !a && !b, undefined if a && b. |
| } |
| |
| // ------------------------------ CountTrue |
| |
| #define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \ |
| return sv##OP##_b##BITS(detail::MakeMask(d), m); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp) |
| #undef HWY_SVE_COUNT_TRUE |
| |
| // For 16-bit Compress: full vector, not limited to SV_POW2. |
| namespace detail { |
| |
| #define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \ |
| return sv##OP##_b##BITS(svptrue_b##BITS(), m); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp) |
| #undef HWY_SVE_COUNT_TRUE_FULL |
| |
| } // namespace detail |
| |
| // ------------------------------ AllFalse |
| template <class D> |
| HWY_API bool AllFalse(D d, svbool_t m) { |
| return !svptest_any(detail::MakeMask(d), m); |
| } |
| |
| // ------------------------------ AllTrue |
| template <class D> |
| HWY_API bool AllTrue(D d, svbool_t m) { |
| return CountTrue(d, m) == Lanes(d); |
| } |
| |
| // ------------------------------ FindFirstTrue |
| template <class D> |
| HWY_API intptr_t FindFirstTrue(D d, svbool_t m) { |
| return AllFalse(d, m) ? intptr_t{-1} |
| : static_cast<intptr_t>( |
| CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m))); |
| } |
| |
| // ------------------------------ FindKnownFirstTrue |
| template <class D> |
| HWY_API size_t FindKnownFirstTrue(D d, svbool_t m) { |
| return CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m)); |
| } |
| |
| // ------------------------------ IfThenElse |
| #define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \ |
| return sv##OP##_##CHAR##BITS(m, yes, no); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) |
| HWY_SVE_FOREACH_BF16(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) |
| #undef HWY_SVE_IF_THEN_ELSE |
| |
| template <class V, class D = DFromV<V>, HWY_SVE_IF_EMULATED_D(D)> |
| HWY_API V IfThenElse(const svbool_t mask, V yes, V no) { |
| const RebindToUnsigned<D> du; |
| return BitCast( |
| D(), IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no))); |
| } |
| |
| // ------------------------------ IfThenElseZero |
| |
| template <class V, class D = DFromV<V>, HWY_SVE_IF_NOT_EMULATED_D(D)> |
| HWY_API V IfThenElseZero(const svbool_t mask, const V yes) { |
| return IfThenElse(mask, yes, Zero(D())); |
| } |
| |
| template <class V, class D = DFromV<V>, HWY_SVE_IF_EMULATED_D(D)> |
| HWY_API V IfThenElseZero(const svbool_t mask, V yes) { |
| const RebindToUnsigned<D> du; |
| return BitCast(D(), IfThenElseZero(RebindMask(du, mask), BitCast(du, yes))); |
| } |
| |
| // ------------------------------ IfThenZeroElse |
| |
| template <class V, class D = DFromV<V>, HWY_SVE_IF_NOT_EMULATED_D(D)> |
| HWY_API V IfThenZeroElse(const svbool_t mask, const V no) { |
| return IfThenElse(mask, Zero(D()), no); |
| } |
| |
| template <class V, class D = DFromV<V>, HWY_SVE_IF_EMULATED_D(D)> |
| HWY_API V IfThenZeroElse(const svbool_t mask, V no) { |
| const RebindToUnsigned<D> du; |
| return BitCast(D(), IfThenZeroElse(RebindMask(du, mask), BitCast(du, no))); |
| } |
| |
| // ------------------------------ Additional mask logical operations |
| HWY_API svbool_t SetBeforeFirst(svbool_t m) { |
| // We don't know the lane type, so assume 8-bit. For larger types, this will |
| // de-canonicalize the predicate, i.e. set bits to 1 even though they do not |
| // correspond to the lowest byte in the lane. Arm says such bits are ignored. |
| return svbrkb_b_z(HWY_SVE_PTRUE(8), m); |
| } |
| |
| HWY_API svbool_t SetAtOrBeforeFirst(svbool_t m) { |
| // We don't know the lane type, so assume 8-bit. For larger types, this will |
| // de-canonicalize the predicate, i.e. set bits to 1 even though they do not |
| // correspond to the lowest byte in the lane. Arm says such bits are ignored. |
| return svbrka_b_z(HWY_SVE_PTRUE(8), m); |
| } |
| |
| HWY_API svbool_t SetOnlyFirst(svbool_t m) { return svbrka_b_z(m, m); } |
| |
| HWY_API svbool_t SetAtOrAfterFirst(svbool_t m) { |
| return Not(SetBeforeFirst(m)); |
| } |
| |
| // ------------------------------ PromoteMaskTo |
| |
| #ifdef HWY_NATIVE_PROMOTE_MASK_TO |
| #undef HWY_NATIVE_PROMOTE_MASK_TO |
| #else |
| #define HWY_NATIVE_PROMOTE_MASK_TO |
| #endif |
| |
| template <class DTo, class DFrom, |
| HWY_IF_T_SIZE_D(DTo, sizeof(TFromD<DFrom>) * 2)> |
| HWY_API svbool_t PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { |
| return svunpklo_b(m); |
| } |
| |
| template <class DTo, class DFrom, |
| HWY_IF_T_SIZE_GT_D(DTo, sizeof(TFromD<DFrom>) * 2)> |
| HWY_API svbool_t PromoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) { |
| using TFrom = TFromD<DFrom>; |
| using TWFrom = MakeWide<MakeUnsigned<TFrom>>; |
| static_assert(sizeof(TWFrom) > sizeof(TFrom), |
| "sizeof(TWFrom) > sizeof(TFrom) must be true"); |
| |
| const Rebind<TWFrom, decltype(d_from)> dw_from; |
| return PromoteMaskTo(d_to, dw_from, PromoteMaskTo(dw_from, d_from, m)); |
| } |
| |
| // ------------------------------ DemoteMaskTo |
| |
| #ifdef HWY_NATIVE_DEMOTE_MASK_TO |
| #undef HWY_NATIVE_DEMOTE_MASK_TO |
| #else |
| #define HWY_NATIVE_DEMOTE_MASK_TO |
| #endif |
| |
| template <class DTo, class DFrom, HWY_IF_T_SIZE_D(DTo, 1), |
| HWY_IF_T_SIZE_D(DFrom, 2)> |
| HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { |
| return svuzp1_b8(m, m); |
| } |
| |
| template <class DTo, class DFrom, HWY_IF_T_SIZE_D(DTo, 2), |
| HWY_IF_T_SIZE_D(DFrom, 4)> |
| HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { |
| return svuzp1_b16(m, m); |
| } |
| |
| template <class DTo, class DFrom, HWY_IF_T_SIZE_D(DTo, 4), |
| HWY_IF_T_SIZE_D(DFrom, 8)> |
| HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { |
| return svuzp1_b32(m, m); |
| } |
| |
| template <class DTo, class DFrom, |
| HWY_IF_T_SIZE_LE_D(DTo, sizeof(TFromD<DFrom>) / 4)> |
| HWY_API svbool_t DemoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) { |
| using TFrom = TFromD<DFrom>; |
| using TNFrom = MakeNarrow<MakeUnsigned<TFrom>>; |
| static_assert(sizeof(TNFrom) < sizeof(TFrom), |
| "sizeof(TNFrom) < sizeof(TFrom) must be true"); |
| |
| const Rebind<TNFrom, decltype(d_from)> dn_from; |
| return DemoteMaskTo(d_to, dn_from, DemoteMaskTo(dn_from, d_from, m)); |
| } |
| |
| // ------------------------------ LowerHalfOfMask |
| #ifdef HWY_NATIVE_LOWER_HALF_OF_MASK |
| #undef HWY_NATIVE_LOWER_HALF_OF_MASK |
| #else |
| #define HWY_NATIVE_LOWER_HALF_OF_MASK |
| #endif |
| |
| template <class D> |
| HWY_API svbool_t LowerHalfOfMask(D /*d*/, svbool_t m) { |
| return m; |
| } |
| |
| // ------------------------------ MaskedAddOr etc. (IfThenElse) |
| |
| #ifdef HWY_NATIVE_MASKED_ARITH |
| #undef HWY_NATIVE_MASKED_ARITH |
| #else |
| #define HWY_NATIVE_MASKED_ARITH |
| #endif |
| |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMin, min) |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMax, max) |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedAdd, add) |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedSub, sub) |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul) |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV, MaskedSqrt, sqrt) |
| #if HWY_SVE_HAVE_2 |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatAdd, qadd) |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatSub, qsub) |
| #endif |
| } // namespace detail |
| |
| template <class V, class M> |
| HWY_API V MaskedMinOr(V no, M m, V a, V b) { |
| return IfThenElse(m, detail::MaskedMin(m, a, b), no); |
| } |
| |
| template <class V, class M> |
| HWY_API V MaskedMaxOr(V no, M m, V a, V b) { |
| return IfThenElse(m, detail::MaskedMax(m, a, b), no); |
| } |
| |
| template <class V, class M> |
| HWY_API V MaskedAddOr(V no, M m, V a, V b) { |
| return IfThenElse(m, detail::MaskedAdd(m, a, b), no); |
| } |
| |
| template <class V, class M> |
| HWY_API V MaskedSubOr(V no, M m, V a, V b) { |
| return IfThenElse(m, detail::MaskedSub(m, a, b), no); |
| } |
| |
| template <class V, class M> |
| HWY_API V MaskedMulOr(V no, M m, V a, V b) { |
| return IfThenElse(m, detail::MaskedMul(m, a, b), no); |
| } |
| |
| template <class V, class M, |
| HWY_IF_T_SIZE_ONE_OF_V( |
| V, (hwy::IsSame<TFromV<V>, hwy::float16_t>() ? (1 << 2) : 0) | |
| (1 << 4) | (1 << 8))> |
| HWY_API V MaskedDivOr(V no, M m, V a, V b) { |
| return IfThenElse(m, detail::MaskedDiv(m, a, b), no); |
| } |
| |
| // I8/U8/I16/U16 MaskedDivOr is implemented after I8/U8/I16/U16 Div |
| |
| #if HWY_SVE_HAVE_2 |
| template <class V, class M> |
| HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { |
| return IfThenElse(m, detail::MaskedSatAdd(m, a, b), no); |
| } |
| |
| template <class V, class M> |
| HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { |
| return IfThenElse(m, detail::MaskedSatSub(m, a, b), no); |
| } |
| #else |
| template <class V, class M> |
| HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { |
| return IfThenElse(m, SaturatedAdd(a, b), no); |
| } |
| |
| template <class V, class M> |
| HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { |
| return IfThenElse(m, SaturatedSub(a, b), no); |
| } |
| #endif |
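
// Example: saturating add only where the mask is true, else take `no`;
// a and b are placeholders.
//   const ScalableTag<uint8_t> d;
//   const svuint8_t r = MaskedSatAddOr(Zero(d), FirstN(d, 4), a, b);
//   // r[i] = (i < 4) ? SaturatedAdd(a, b)[i] : 0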
| |
| // ------------------------------ MaskedMulAddOr |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV, MaskedMulAdd, mad) |
| } |
| |
// Per-target flag to prevent generic_ops-inl.h from defining integer
// MaskedMulAddOr.
| #ifdef HWY_NATIVE_MASKED_INT_FMA |
| #undef HWY_NATIVE_MASKED_INT_FMA |
| #else |
| #define HWY_NATIVE_MASKED_INT_FMA |
| #endif |
| |
| template <class V, class M> |
| HWY_API V MaskedMulAddOr(V no, M m, V mul, V x, V add) { |
| return IfThenElse(m, detail::MaskedMulAdd(m, mul, x, add), no); |
| } |
| |
| template <class V, HWY_IF_FLOAT_V(V), class M> |
| HWY_API V MaskedSqrtOr(V no, M m, V v) { |
| return IfThenElse(m, detail::MaskedSqrt(m, v), no); |
| } |
| |
| // ================================================== REDUCE |
| |
| #ifdef HWY_NATIVE_REDUCE_SCALAR |
| #undef HWY_NATIVE_REDUCE_SCALAR |
| #else |
| #define HWY_NATIVE_REDUCE_SCALAR |
| #endif |
| |
| // These return T, suitable for ReduceSum. |
| namespace detail { |
| #define HWY_SVE_REDUCE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ |
| /* The intrinsic returns [u]int64_t; truncate to T so we can broadcast. */ \ |
| using T = HWY_SVE_T(BASE, BITS); \ |
| using TU = MakeUnsigned<T>; \ |
| constexpr uint64_t kMask = LimitsMax<TU>(); \ |
| return static_cast<T>(static_cast<TU>( \ |
| static_cast<uint64_t>(sv##OP##_##CHAR##BITS(pg, v)) & kMask)); \ |
| } |
| |
| #define HWY_SVE_REDUCE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS(pg, v); \ |
| } |
| |
| // TODO: Remove SumOfLanesM in favor of using MaskedReduceSum |
| HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv) |
| HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv) |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanesM, minv) |
| HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanesM, maxv) |
| // NaN if all are |
| HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanesM, minnmv) |
| HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanesM, maxnmv) |
| |
| #undef HWY_SVE_REDUCE |
| #undef HWY_SVE_REDUCE_ADD |
| } // namespace detail |
| |
// detail::SumOfLanesM, detail::MinOfLanesM, and detail::MaxOfLanesM are more
// efficient for N=4 I8/U8 reductions on SVE than the default implementations
// of the N=4 I8/U8 ReduceSum/ReduceMin/ReduceMax operations in
// generic_ops-inl.h.
| #undef HWY_IF_REDUCE_D |
| #define HWY_IF_REDUCE_D(D) hwy::EnableIf<HWY_MAX_LANES_D(D) != 1>* = nullptr |
| |
| #ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 |
| #undef HWY_NATIVE_REDUCE_SUM_4_UI8 |
| #else |
| #define HWY_NATIVE_REDUCE_SUM_4_UI8 |
| #endif |
| |
| #ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8 |
| #undef HWY_NATIVE_REDUCE_MINMAX_4_UI8 |
| #else |
| #define HWY_NATIVE_REDUCE_MINMAX_4_UI8 |
| #endif |
| |
| template <class D, HWY_IF_REDUCE_D(D)> |
| HWY_API TFromD<D> ReduceSum(D d, VFromD<D> v) { |
| return detail::SumOfLanesM(detail::MakeMask(d), v); |
| } |
| |
| template <class D, HWY_IF_REDUCE_D(D)> |
| HWY_API TFromD<D> ReduceMin(D d, VFromD<D> v) { |
| return detail::MinOfLanesM(detail::MakeMask(d), v); |
| } |
| |
| template <class D, HWY_IF_REDUCE_D(D)> |
| HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) { |
| return detail::MaxOfLanesM(detail::MakeMask(d), v); |
| } |
| |
| #ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR |
| #undef HWY_NATIVE_MASKED_REDUCE_SCALAR |
| #else |
| #define HWY_NATIVE_MASKED_REDUCE_SCALAR |
| #endif |
| |
| template <class D, class M> |
| HWY_API TFromD<D> MaskedReduceSum(D /*d*/, M m, VFromD<D> v) { |
| return detail::SumOfLanesM(m, v); |
| } |
| template <class D, class M> |
| HWY_API TFromD<D> MaskedReduceMin(D /*d*/, M m, VFromD<D> v) { |
| return detail::MinOfLanesM(m, v); |
| } |
| template <class D, class M> |
| HWY_API TFromD<D> MaskedReduceMax(D /*d*/, M m, VFromD<D> v) { |
| return detail::MaxOfLanesM(m, v); |
| } |
| |
| // ------------------------------ SumOfLanes |
| |
| template <class D, HWY_IF_LANES_GT_D(D, 1)> |
| HWY_API VFromD<D> SumOfLanes(D d, VFromD<D> v) { |
| return Set(d, ReduceSum(d, v)); |
| } |
| template <class D, HWY_IF_LANES_GT_D(D, 1)> |
| HWY_API VFromD<D> MinOfLanes(D d, VFromD<D> v) { |
| return Set(d, ReduceMin(d, v)); |
| } |
| template <class D, HWY_IF_LANES_GT_D(D, 1)> |
| HWY_API VFromD<D> MaxOfLanes(D d, VFromD<D> v) { |
| return Set(d, ReduceMax(d, v)); |
| } |
| |
| // ------------------------------ MaskedAdd etc. (IfThenElse) |
| |
| #ifdef HWY_NATIVE_ZERO_MASKED_ARITH |
| #undef HWY_NATIVE_ZERO_MASKED_ARITH |
| #else |
| #define HWY_NATIVE_ZERO_MASKED_ARITH |
| #endif |
| |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedMax, max) |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedAdd, add) |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedSub, sub) |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedMul, mul) |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV_Z, MaskedMulAdd, mad) |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV_Z, MaskedNegMulAdd, msb) |
| |
| // I8/U8/I16/U16 MaskedDiv is implemented after I8/U8/I16/U16 Div |
| |
| #if HWY_SVE_HAVE_2 |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedSaturatedAdd, qadd) |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedSaturatedSub, qsub) |
| #else |
| template <class V, class M> |
| HWY_API V MaskedSaturatedAdd(M m, V a, V b) { |
| return IfThenElseZero(m, SaturatedAdd(a, b)); |
| } |
| |
| template <class V, class M> |
| HWY_API V MaskedSaturatedSub(M m, V a, V b) { |
| return IfThenElseZero(m, SaturatedSub(a, b)); |
| } |
| #endif |
| |
| template <class V, class M, typename D = DFromV<V>, HWY_IF_I16_D(D)> |
| HWY_API V MaskedMulFixedPoint15(M m, V a, V b) { |
| return IfThenElseZero(m, MulFixedPoint15(a, b)); |
| } |
| |
| template <class D, class M, HWY_IF_UI32_D(D), |
| class V16 = VFromD<RepartitionToNarrow<D>>> |
| HWY_API VFromD<D> MaskedWidenMulPairwiseAdd(D d32, M m, V16 a, V16 b) { |
| return IfThenElseZero(m, WidenMulPairwiseAdd(d32, a, b)); |
| } |
| |
| template <class DF, class M, HWY_IF_F32_D(DF), class VBF> |
| HWY_API VFromD<DF> MaskedWidenMulPairwiseAdd(DF df, M m, VBF a, VBF b) { |
| return IfThenElseZero(m, WidenMulPairwiseAdd(df, a, b)); |
| } |
| |
| // ================================================== COMPARE |
| |
| // mask = f(vector, vector) |
| #define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ |
| } |
| #define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ |
| } |
| |
| // ------------------------------ Eq |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq) |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n) |
| } // namespace detail |
| |
| // ------------------------------ Ne |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne) |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n) |
| } // namespace detail |
| |
| // ------------------------------ Lt |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt) |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n) |
| } // namespace detail |
| |
| // ------------------------------ Le |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE, Le, cmple) |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LeN, cmple_n) |
| } // namespace detail |
| |
| // ------------------------------ Gt/Ge (swapped order) |
| template <class V> |
| HWY_API svbool_t Gt(const V a, const V b) { |
| return Lt(b, a); |
| } |
| template <class V> |
| HWY_API svbool_t Ge(const V a, const V b) { |
| return Le(b, a); |
| } |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GeN, cmpge_n) |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GtN, cmpgt_n) |
| } // namespace detail |
| |
| #undef HWY_SVE_COMPARE |
| #undef HWY_SVE_COMPARE_N |
| |
| // ------------------------------ TestBit |
| template <class V> |
| HWY_API svbool_t TestBit(const V a, const V bit) { |
| return detail::NeN(And(a, bit), 0); |
| } |
| |
| // ------------------------------ Min/Max (Lt, IfThenElse) |
| |
| HWY_SVE_FOREACH_U(HWY_SVE_RETV_ARGPVV, Min, min) |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max) |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm) |
| |
| // Workaround for incorrect results with `svmin`. |
| #if HWY_COMPILER_CLANG |
| template <class V, HWY_IF_SIGNED_V(V)> |
| HWY_API V Min(V a, V b) { |
| return IfThenElse(Lt(a, b), a, b); |
| } |
| template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)> |
| HWY_API V Min(V a, V b) { |
| return IfThenElse(Lt(a, b), a, b); |
| } |
| |
| #else |
| HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPVV, Min, min) |
| HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm) |
| #endif |
| |
| namespace detail { |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n) |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n) |
| } // namespace detail |
| |
| // ================================================== SWIZZLE |
| |
| // ------------------------------ ConcatEven/ConcatOdd |
| |
| // WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the |
| // full vector length, not rounded down to a power of two as we require). |
| namespace detail { |
| |
| #define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_INLINE HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ |
| return sv##OP##_##CHAR##BITS(lo, hi); \ |
| } |
| HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, uzp1) |
| HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, uzp2) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, |
| uzp1) |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, |
| uzp2) |
| #endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| #if defined(__ARM_FEATURE_SVE_MATMUL_FP64) |
| HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenBlocks, uzp1q) |
| HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, uzp2q) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, |
| ConcatEvenBlocks, uzp1q) |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, |
| uzp2q) |
| #endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| #endif // defined(__ARM_FEATURE_SVE_MATMUL_FP64) |
| #undef HWY_SVE_CONCAT_EVERY_SECOND |
| |
| // Used to slide up / shift whole register left; mask indicates which range |
| // to take from lo, and the rest is filled from hi starting at its lowest. |
| #define HWY_SVE_SPLICE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME( \ |
| HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \ |
| return sv##OP##_##CHAR##BITS(mask, lo, hi); \ |
| } |
| HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice) |
| #if HWY_SVE_HAVE_BF16_FEATURE |
| HWY_SVE_FOREACH_BF16(HWY_SVE_SPLICE, Splice, splice) |
| #else |
| template <class V, HWY_IF_BF16_D(DFromV<V>)> |
| HWY_INLINE V Splice(V hi, V lo, svbool_t mask) { |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, Splice(BitCast(du, hi), BitCast(du, lo), mask)); |
| } |
| #endif // HWY_SVE_HAVE_BF16_FEATURE |
| #undef HWY_SVE_SPLICE |
| |
| } // namespace detail |
| |
| template <class D> |
| HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) { |
| #if HWY_SVE_IS_POW2 |
| if (detail::IsFull(d)) return detail::ConcatOddFull(hi, lo); |
| #endif |
| const VFromD<D> hi_odd = detail::ConcatOddFull(hi, hi); |
| const VFromD<D> lo_odd = detail::ConcatOddFull(lo, lo); |
| return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); |
| } |
| |
| template <class D> |
| HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) { |
| #if HWY_SVE_IS_POW2 |
| if (detail::IsFull(d)) return detail::ConcatEvenFull(hi, lo); |
| #endif |
| const VFromD<D> hi_odd = detail::ConcatEvenFull(hi, hi); |
| const VFromD<D> lo_odd = detail::ConcatEvenFull(lo, lo); |
| return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); |
| } |
| |
| HWY_API svuint8_t U8FromU32(const svuint32_t v) { |
| const DFromV<svuint32_t> du32; |
| const RepartitionToNarrow<decltype(du32)> du16; |
| const RepartitionToNarrow<decltype(du16)> du8; |
| |
| const svuint16_t cast16 = BitCast(du16, v); |
| const svuint16_t x2 = svuzp1_u16(cast16, cast16); |
| const svuint8_t cast8 = BitCast(du8, x2); |
| return svuzp1_u8(cast8, cast8); |
| } |
| |
| // ================================================== MASK |
| |
| // ------------------------------ MaskFromVec (Ne) |
| template <class V> |
| HWY_API svbool_t MaskFromVec(const V v) { |
| using T = TFromV<V>; |
| return detail::NeN(v, ConvertScalarTo<T>(0)); |
| } |
| |
| // ------------------------------ VecFromMask |
| template <class D> |
| HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) { |
| const RebindToSigned<D> di; |
| // This generates MOV imm, whereas svdup_n_s8_z generates MOV scalar, which |
| // requires an extra instruction plus M0 pipeline. |
| return BitCast(d, IfThenElseZero(mask, Set(di, -1))); |
| } |
| |
// ------------------------------ BitsFromMask (AndN, Shl, ReduceSum, GetLane,
// ConcatEvenFull, U8FromU32)
| |
| namespace detail { |
| |
| // For each mask lane (governing lane type T), store 1 or 0 in BYTE lanes. |
| template <class D, HWY_IF_T_SIZE_D(D, 1)> |
| HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { |
| return svdup_n_u8_z(m, 1); |
| } |
| template <class D, HWY_IF_T_SIZE_D(D, 2)> |
| HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { |
| const ScalableTag<uint8_t> d8; |
| const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1)); |
| return detail::ConcatEvenFull(b16, b16); // lower half |
| } |
| template <class D, HWY_IF_T_SIZE_D(D, 4)> |
| HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { |
| return U8FromU32(svdup_n_u32_z(m, 1)); |
| } |
| template <class D, HWY_IF_T_SIZE_D(D, 8)> |
| HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { |
| const ScalableTag<uint32_t> d32; |
| const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1)); |
| return U8FromU32(detail::ConcatEvenFull(b64, b64)); // lower half |
| } |
| |
| // Compacts groups of 8 u8 into 8 contiguous bits in a 64-bit lane. |
| HWY_INLINE svuint64_t BitsFromBool(svuint8_t x) { |
| const ScalableTag<uint8_t> d8; |
| const ScalableTag<uint16_t> d16; |
| const ScalableTag<uint32_t> d32; |
| const ScalableTag<uint64_t> d64; |
| // TODO(janwas): could use SVE2 BDEP, but it's optional. |
| x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x)))); |
| x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x)))); |
| x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x)))); |
| return BitCast(d64, x); |
| } |
| |
| } // namespace detail |
| |
| // BitsFromMask is required if `HWY_MAX_BYTES <= 64`, which is true for the |
| // fixed-size SVE targets. |
| #if HWY_TARGET == HWY_SVE2_128 || HWY_TARGET == HWY_SVE_256 |
| template <class D> |
| HWY_API uint64_t BitsFromMask(D d, svbool_t mask) { |
| const Repartition<uint64_t, D> du64; |
| svuint64_t bits_in_u64 = detail::BitsFromBool(detail::BoolFromMask<D>(mask)); |
| |
| constexpr size_t N = MaxLanes(d); |
| static_assert(N < 64, "SVE2_128 and SVE_256 are only 128 or 256 bits"); |
| const uint64_t valid = (1ull << N) - 1; |
| HWY_IF_CONSTEXPR(N <= 8) { |
| // Upper bits are undefined even if N == 8, hence mask. |
| return GetLane(bits_in_u64) & valid; |
| } |
| |
| // Up to 8 of the least-significant bits of each u64 lane are valid. |
| bits_in_u64 = detail::AndN(bits_in_u64, 0xFF); |
| |
| // 128-bit vector: only two u64, so avoid ReduceSum. |
| HWY_IF_CONSTEXPR(HWY_TARGET == HWY_SVE2_128) { |
| alignas(16) uint64_t lanes[2]; |
| Store(bits_in_u64, du64, lanes); |
    // lanes[0] is always valid because we know N > 8, but lanes[1] might not
    // be; if needed, we mask it out below.
| const uint64_t result = lanes[0] + (lanes[1] << 8); |
| // 8-bit lanes, no further masking |
| HWY_IF_CONSTEXPR(N == 16) return result; |
| return result & valid; |
| } |
| |
| // Shift the 8-bit groups into place in each u64 lane. |
| alignas(32) uint64_t kShifts[4] = {0 * 8, 1 * 8, 2 * 8, 3 * 8}; |
| bits_in_u64 = Shl(bits_in_u64, Load(du64, kShifts)); |
| return ReduceSum(du64, bits_in_u64) & valid; |
| } |
| |
| #endif // HWY_TARGET == HWY_SVE2_128 || HWY_TARGET == HWY_SVE_256 |
| |
| // ------------------------------ IsNegative (Lt) |
| #ifdef HWY_NATIVE_IS_NEGATIVE |
| #undef HWY_NATIVE_IS_NEGATIVE |
| #else |
| #define HWY_NATIVE_IS_NEGATIVE |
| #endif |
| |
| template <class V, HWY_IF_NOT_UNSIGNED_V(V)> |
| HWY_API svbool_t IsNegative(V v) { |
| const DFromV<decltype(v)> d; |
| const RebindToSigned<decltype(d)> di; |
| using TI = TFromD<decltype(di)>; |
| |
| return detail::LtN(BitCast(di, v), static_cast<TI>(0)); |
| } |
| |
| // ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse) |
| |
| #if HWY_SVE_HAVE_2 |
| |
| #define HWY_SVE_IF_VEC(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) yes, \ |
| HWY_SVE_V(BASE, BITS) no) { \ |
| return sv##OP##_##CHAR##BITS(yes, no, mask); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_IF_VEC, IfVecThenElse, bsl) |
| #undef HWY_SVE_IF_VEC |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast( |
| d, IfVecThenElse(BitCast(du, mask), BitCast(du, yes), BitCast(du, no))); |
| } |
| |
| #else |
| |
| template <class V> |
| HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { |
| return Or(And(mask, yes), AndNot(mask, no)); |
| } |
| |
| #endif // HWY_SVE_HAVE_2 |
| |
| // ------------------------------ BitwiseIfThenElse |
| |
| #ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE |
| #undef HWY_NATIVE_BITWISE_IF_THEN_ELSE |
| #else |
| #define HWY_NATIVE_BITWISE_IF_THEN_ELSE |
| #endif |
| |
| template <class V> |
| HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { |
| return IfVecThenElse(mask, yes, no); |
| } |
| |
| // ------------------------------ CopySign (BitwiseIfThenElse) |
| template <class V> |
| HWY_API V CopySign(const V magn, const V sign) { |
| const DFromV<decltype(magn)> d; |
| return BitwiseIfThenElse(SignBit(d), sign, magn); |
| } |
| |
| // ------------------------------ CopySignToAbs |
| template <class V> |
| HWY_API V CopySignToAbs(const V abs, const V sign) { |
| #if HWY_SVE_HAVE_2 // CopySign is more efficient than OrAnd |
| return CopySign(abs, sign); |
| #else |
| const DFromV<V> d; |
| return OrAnd(abs, SignBit(d), sign); |
| #endif |
| } |
| |
| // ------------------------------ Floating-point classification (Ne) |
| |
| template <class V> |
| HWY_API svbool_t IsNaN(const V v) { |
| return Ne(v, v); // could also use cmpuo |
| } |
| |
| // Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. |
| // We use a fused Set/comparison for IsFinite. |
| #ifdef HWY_NATIVE_ISINF |
| #undef HWY_NATIVE_ISINF |
| #else |
| #define HWY_NATIVE_ISINF |
| #endif |
| |
| template <class V> |
| HWY_API svbool_t IsInf(const V v) { |
| using T = TFromV<V>; |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| const RebindToSigned<decltype(d)> di; |
| |
| // 'Shift left' to clear the sign bit |
| const VFromD<decltype(du)> vu = BitCast(du, v); |
| const VFromD<decltype(du)> v2 = Add(vu, vu); |
| // Check for exponent=max and mantissa=0. |
| const VFromD<decltype(di)> max2 = Set(di, hwy::MaxExponentTimes2<T>()); |
| return RebindMask(d, Eq(v2, BitCast(du, max2))); |
| } |
| |
| // Returns whether normal/subnormal/zero. |
| template <class V> |
| HWY_API svbool_t IsFinite(const V v) { |
| using T = TFromV<V>; |
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| const RebindToSigned<decltype(d)> di; // cheaper than unsigned comparison |
| const VFromD<decltype(du)> vu = BitCast(du, v); |
| // 'Shift left' to clear the sign bit, then right so we can compare with the |
| // max exponent (cannot compare with MaxExponentTimes2 directly because it is |
| // negative and non-negative floats would be greater). |
| const VFromD<decltype(di)> exp = |
| BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu))); |
| return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField<T>())); |
| } |
| |
| // ------------------------------ MulByPow2/MulByFloorPow2 |
| |
| #define HWY_SVE_MUL_BY_POW2(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(int, BITS) exp) { \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, exp); \ |
| } |
| |
| HWY_SVE_FOREACH_F(HWY_SVE_MUL_BY_POW2, MulByPow2, scale) |
| |
| #undef HWY_SVE_MUL_BY_POW2 |
| |
| // ------------------------------ MaskedEq etc. |
| #ifdef HWY_NATIVE_MASKED_COMP |
| #undef HWY_NATIVE_MASKED_COMP |
| #else |
| #define HWY_NATIVE_MASKED_COMP |
| #endif |
| |
| // mask = f(mask, vector, vector) |
| #define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, \ |
| HWY_SVE_V(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS(m, a, b); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq) |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne) |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt) |
| HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple) |
| |
| #undef HWY_SVE_COMPARE_Z |
| |
| template <class V, class M, class D = DFromV<V>> |
| HWY_API MFromD<D> MaskedGt(M m, V a, V b) { |
| // Swap args to reverse comparison |
| return MaskedLt(m, b, a); |
| } |
| |
| template <class V, class M, class D = DFromV<V>> |
| HWY_API MFromD<D> MaskedGe(M m, V a, V b) { |
| // Swap args to reverse comparison |
| return MaskedLe(m, b, a); |
| } |
| |
| template <class V, class M, class D = DFromV<V>> |
| HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) { |
| return MaskedNe(m, v, v); |
| } |
| |
| // ================================================== MEMORY |
| |
| // ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream |
| |
| #define HWY_SVE_MEM(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| LoadU(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ |
| return svld1_##CHAR##BITS(detail::MakeMask(d), \ |
| detail::NativeLanePointer(p)); \ |
| } \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| MaskedLoad(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ |
| const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ |
| return svld1_##CHAR##BITS(m, detail::NativeLanePointer(p)); \ |
| } \ |
| template <size_t N, int kPow2> \ |
| HWY_API void StoreU(HWY_SVE_V(BASE, BITS) v, \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ |
| svst1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), v); \ |
| } \ |
| template <size_t N, int kPow2> \ |
| HWY_API void Stream(HWY_SVE_V(BASE, BITS) v, \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ |
| svstnt1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), \ |
| v); \ |
| } \ |
| template <size_t N, int kPow2> \ |
| HWY_API void BlendedStore(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ |
| HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ |
| svst1_##CHAR##BITS(m, detail::NativeLanePointer(p), v); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_MEM, _, _) |
| HWY_SVE_FOREACH_BF16(HWY_SVE_MEM, _, _) |
| |
| template <class D, HWY_SVE_IF_EMULATED_D(D)> |
| HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) { |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, LoadU(du, detail::U16LanePointer(p))); |
| } |
| |
| template <class D, HWY_SVE_IF_EMULATED_D(D)> |
| HWY_API void StoreU(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT p) { |
| const RebindToUnsigned<decltype(d)> du; |
| StoreU(BitCast(du, v), du, detail::U16LanePointer(p)); |
| } |
| |
| template <class D, HWY_SVE_IF_EMULATED_D(D)> |
| HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D d, |
| const TFromD<D>* HWY_RESTRICT p) { |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, |
| MaskedLoad(RebindMask(du, m), du, detail::U16LanePointer(p))); |
| } |
| |
| // MaskedLoadOr is generic and does not require emulation. |
| |
| template <class D, HWY_SVE_IF_EMULATED_D(D)> |
| HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d, |
| TFromD<D>* HWY_RESTRICT p) { |
| const RebindToUnsigned<decltype(d)> du; |
| BlendedStore(BitCast(du, v), RebindMask(du, m), du, |
| detail::U16LanePointer(p)); |
| } |
| |
| #undef HWY_SVE_MEM |
| |
| #if HWY_TARGET != HWY_SVE2_128 |
| namespace detail { |
| #define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ |
| const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ |
| /* All-true predicate to load all 128 bits. */ \ |
| return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), \ |
| detail::NativeLanePointer(p)); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq) |
| HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq) |
| |
| template <class D, HWY_SVE_IF_EMULATED_D(D)> |
| HWY_API VFromD<D> LoadDupFull128(D d, const TFromD<D>* HWY_RESTRICT p) { |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, LoadDupFull128(du, detail::U16LanePointer(p))); |
| } |
| |
| } // namespace detail |
| #endif // HWY_TARGET != HWY_SVE2_128 |
| |
| #if HWY_TARGET == HWY_SVE2_128 |
// On the HWY_SVE2_128 target, LoadDup128 is the same as LoadU because vectors
// cannot exceed 16 bytes.
| template <class D> |
| HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) { |
| return LoadU(d, p); |
| } |
| #else // HWY_TARGET != HWY_SVE2_128 |
// If D().MaxBytes() <= 16, a single LoadU suffices.
| template <class D, HWY_IF_V_SIZE_LE_D(D, 16)> |
| HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) { |
| return LoadU(d, p); |
| } |
| |
// If D().MaxBytes() > 16, we need to load and broadcast the 128-bit block
// using ld1rq.
| template <class D, HWY_IF_V_SIZE_GT_D(D, 16)> |
| HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) { |
| return detail::LoadDupFull128(d, p); |
| } |
| |
| #endif // HWY_TARGET != HWY_SVE2_128 |
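
// Example: broadcast the same 128-bit block (e.g. four floats) to every
// block of the vector; p is a placeholder with at least four valid elements.
//   const ScalableTag<float> d;
//   const svfloat32_t quad = LoadDup128(d, p);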
| |
| // Truncate to smaller size and store |
| #ifdef HWY_NATIVE_STORE_TRUNCATED |
| #undef HWY_NATIVE_STORE_TRUNCATED |
| #else |
| #define HWY_NATIVE_STORE_TRUNCATED |
| #endif |
| |
| #define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \ |
| template <size_t N, int kPow2> \ |
| HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ |
| const HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \ |
| sv##OP##_##CHAR##BITS(detail::PTrue(d), detail::NativeLanePointer(p), v); \ |
| } |
| |
| #define HWY_SVE_STORE_TRUNCATED_BYTE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 8) |
| #define HWY_SVE_STORE_TRUNCATED_HALF(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 16) |
| #define HWY_SVE_STORE_TRUNCATED_WORD(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 32) |
| |
| HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, TruncateStore, st1w) |
| |
| #undef HWY_SVE_STORE_TRUNCATED |
| |
| // ------------------------------ Load/Store |
| |
| // SVE only requires lane alignment, not natural alignment of the entire |
| // vector, so Load/Store are the same as LoadU/StoreU. |
| template <class D> |
| HWY_API VFromD<D> Load(D d, const TFromD<D>* HWY_RESTRICT p) { |
| return LoadU(d, p); |
| } |
| |
| template <class V, class D> |
| HWY_API void Store(const V v, D d, TFromD<D>* HWY_RESTRICT p) { |
| StoreU(v, d, p); |
| } |
| |
| // ------------------------------ MaskedLoadOr |
| |
| // SVE MaskedLoad hard-codes zero, so this requires an extra blend. |
| template <class D> |
| HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D d, |
| const TFromD<D>* HWY_RESTRICT p) { |
| return IfThenElse(m, MaskedLoad(m, d, p), v); |
| } |
| |
| // ------------------------------ ScatterOffset/Index |
| |
| #ifdef HWY_NATIVE_SCATTER |
| #undef HWY_NATIVE_SCATTER |
| #else |
| #define HWY_NATIVE_SCATTER |
| #endif |
| |
| #define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ |
| HWY_SVE_V(int, BITS) offset) { \ |
| sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, offset, \ |
| v); \ |
| } |
| |
| #define HWY_SVE_MASKED_SCATTER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, \ |
| HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ |
| HWY_SVE_V(int, BITS) indices) { \ |
| sv##OP##_s##BITS##index_##CHAR##BITS(m, base, indices, v); \ |
| } |
| |
| HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter) |
| HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_SCATTER_INDEX, MaskedScatterIndex, |
| st1_scatter) |
| #undef HWY_SVE_SCATTER_OFFSET |
| #undef HWY_SVE_MASKED_SCATTER_INDEX |
| |
| template <class D> |
| HWY_API void ScatterIndex(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT p, |
| VFromD<RebindToSigned<D>> indices) { |
| MaskedScatterIndex(v, detail::MakeMask(d), d, p, indices); |
| } |
| |
| // ------------------------------ GatherOffset/Index |
| |
| #ifdef HWY_NATIVE_GATHER |
| #undef HWY_NATIVE_GATHER |
| #else |
| #define HWY_NATIVE_GATHER |
| #endif |
| |
| #define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ |
| HWY_SVE_V(int, BITS) offset) { \ |
| return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, \ |
| offset); \ |
| } |
| #define HWY_SVE_MASKED_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ |
| HWY_SVE_V(int, BITS) indices) { \ |
| const RebindToSigned<decltype(d)> di; \ |
| (void)di; /* for HWY_DASSERT */ \ |
| HWY_DASSERT(AllFalse(di, Lt(indices, Zero(di)))); \ |
| return sv##OP##_s##BITS##index_##CHAR##BITS(m, base, indices); \ |
| } |
| |
| HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather) |
| HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_GATHER_INDEX, MaskedGatherIndex, |
| ld1_gather) |
| #undef HWY_SVE_GATHER_OFFSET |
| #undef HWY_SVE_MASKED_GATHER_INDEX |
| |
| template <class D> |
| HWY_API VFromD<D> MaskedGatherIndexOr(VFromD<D> no, svbool_t m, D d, |
| const TFromD<D>* HWY_RESTRICT p, |
| VFromD<RebindToSigned<D>> indices) { |
| return IfThenElse(m, MaskedGatherIndex(m, d, p, indices), no); |
| } |
| |
| template <class D> |
| HWY_API VFromD<D> GatherIndex(D d, const TFromD<D>* HWY_RESTRICT p, |
| VFromD<RebindToSigned<D>> indices) { |
| return MaskedGatherIndex(detail::MakeMask(d), d, p, indices); |
| } |
| |
| // ------------------------------ LoadInterleaved2 |
| |
| // Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. |
| #ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED |
| #undef HWY_NATIVE_LOAD_STORE_INTERLEAVED |
| #else |
| #define HWY_NATIVE_LOAD_STORE_INTERLEAVED |
| #endif |
| |
| #define HWY_SVE_LOAD2(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ |
| HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1) { \ |
| const HWY_SVE_TUPLE(BASE, BITS, 2) tuple = sv##OP##_##CHAR##BITS( \ |
| detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ |
| v0 = svget2(tuple, 0); \ |
| v1 = svget2(tuple, 1); \ |
| } |
| HWY_SVE_FOREACH(HWY_SVE_LOAD2, LoadInterleaved2, ld2) |
| HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD2, LoadInterleaved2, ld2) |
| |
| #undef HWY_SVE_LOAD2 |
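
// Example: deinterleave re/im pairs stored as [re0 im0 re1 im1 ...];
// p is a placeholder.
//   const ScalableTag<float> d;
//   svfloat32_t re, im;
//   LoadInterleaved2(d, p, re, im);  // re = p[0,2,...], im = p[1,3,...]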
| |
| // ------------------------------ LoadInterleaved3 |
| |
| #define HWY_SVE_LOAD3(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ |
| HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ |
| HWY_SVE_V(BASE, BITS) & v2) { \ |
| const HWY_SVE_TUPLE(BASE, BITS, 3) tuple = sv##OP##_##CHAR##BITS( \ |
| detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ |
| v0 = svget3(tuple, 0); \ |
| v1 = svget3(tuple, 1); \ |
| v2 = svget3(tuple, 2); \ |
| } |
| HWY_SVE_FOREACH(HWY_SVE_LOAD3, LoadInterleaved3, ld3) |
| HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD3, LoadInterleaved3, ld3) |
| |
| #undef HWY_SVE_LOAD3 |
| |
| // ------------------------------ LoadInterleaved4 |
| |
| #define HWY_SVE_LOAD4(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ |
| HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ |
| HWY_SVE_V(BASE, BITS) & v2, HWY_SVE_V(BASE, BITS) & v3) { \ |
| const HWY_SVE_TUPLE(BASE, BITS, 4) tuple = sv##OP##_##CHAR##BITS( \ |
| detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ |
| v0 = svget4(tuple, 0); \ |
| v1 = svget4(tuple, 1); \ |
| v2 = svget4(tuple, 2); \ |
| v3 = svget4(tuple, 3); \ |
| } |
| HWY_SVE_FOREACH(HWY_SVE_LOAD4, LoadInterleaved4, ld4) |
| HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD4, LoadInterleaved4, ld4) |
| |
| #undef HWY_SVE_LOAD4 |
| |
| // ------------------------------ StoreInterleaved2 |
| |
| #define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ |
| sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ |
| detail::NativeLanePointer(unaligned), \ |
| Create2(d, v0, v1)); \ |
| } |
| HWY_SVE_FOREACH(HWY_SVE_STORE2, StoreInterleaved2, st2) |
| HWY_SVE_FOREACH_BF16(HWY_SVE_STORE2, StoreInterleaved2, st2) |
| |
| #undef HWY_SVE_STORE2 |
| |
| // ------------------------------ StoreInterleaved3 |
| |
| #define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ |
| HWY_SVE_V(BASE, BITS) v2, \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ |
| sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ |
| detail::NativeLanePointer(unaligned), \ |
| Create3(d, v0, v1, v2)); \ |
| } |
| HWY_SVE_FOREACH(HWY_SVE_STORE3, StoreInterleaved3, st3) |
| HWY_SVE_FOREACH_BF16(HWY_SVE_STORE3, StoreInterleaved3, st3) |
| |
| #undef HWY_SVE_STORE3 |
| |
| // ------------------------------ StoreInterleaved4 |
| |
| #define HWY_SVE_STORE4(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ |
| HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) d, \ |
| HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ |
| sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ |
| detail::NativeLanePointer(unaligned), \ |
| Create4(d, v0, v1, v2, v3)); \ |
| } |
| HWY_SVE_FOREACH(HWY_SVE_STORE4, StoreInterleaved4, st4) |
| HWY_SVE_FOREACH_BF16(HWY_SVE_STORE4, StoreInterleaved4, st4) |
| |
| #undef HWY_SVE_STORE4 |
| |
| // Fall back on generic Load/StoreInterleaved[234] for any emulated types. |
// This requires that HWY_GENERIC_IF_EMULATED_D mirror HWY_SVE_IF_EMULATED_D.
| |
| // ================================================== CONVERT |
| |
| // ------------------------------ PromoteTo |
| |
| // Same sign |
| #define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME( \ |
| HWY_SVE_D(BASE, BITS, N, kPow2) /* tag */, HWY_SVE_V(BASE, HALF) v) { \ |
| return sv##OP##_##CHAR##BITS(v); \ |
| } |
| |
| HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) |
| |
| // 2x |
| template <size_t N, int kPow2> |
| HWY_API svuint32_t PromoteTo(Simd<uint32_t, N, kPow2> dto, svuint8_t vfrom) { |
| const RepartitionToWide<DFromV<decltype(vfrom)>> d2; |
| return PromoteTo(dto, PromoteTo(d2, vfrom)); |
| } |
| template <size_t N, int kPow2> |
| HWY_API svint32_t PromoteTo(Simd<int32_t, N, kPow2> dto, svint8_t vfrom) { |
| const RepartitionToWide<DFromV<decltype(vfrom)>> d2; |
| return PromoteTo(dto, PromoteTo(d2, vfrom)); |
| } |
| template <size_t N, int kPow2> |
| HWY_API svuint64_t PromoteTo(Simd<uint64_t, N, kPow2> dto, svuint16_t vfrom) { |
| const RepartitionToWide<DFromV<decltype(vfrom)>> d2; |
| return PromoteTo(dto, PromoteTo(d2, vfrom)); |
| } |
| template <size_t N, int kPow2> |
| HWY_API svint64_t PromoteTo(Simd<int64_t, N, kPow2> dto, svint16_t vfrom) { |
| const RepartitionToWide<DFromV<decltype(vfrom)>> d2; |
| return PromoteTo(dto, PromoteTo(d2, vfrom)); |
| } |
| |
| // 3x |
| template <size_t N, int kPow2> |
| HWY_API svuint64_t PromoteTo(Simd<uint64_t, N, kPow2> dto, svuint8_t vfrom) { |
| const RepartitionToNarrow<decltype(dto)> d4; |
| const RepartitionToNarrow<decltype(d4)> d2; |
| return PromoteTo(dto, PromoteTo(d4, PromoteTo(d2, vfrom))); |
| } |
| template <size_t N, int kPow2> |
| HWY_API svint64_t PromoteTo(Simd<int64_t, N, kPow2> dto, svint8_t vfrom) { |
| const RepartitionToNarrow<decltype(dto)> d4; |
| const RepartitionToNarrow<decltype(d4)> d2; |
| return PromoteTo(dto, PromoteTo(d4, PromoteTo(d2, vfrom))); |
| } |
| |
| // Sign change |
| template <class D, class V, HWY_IF_SIGNED_D(D), HWY_IF_UNSIGNED_V(V), |
| HWY_IF_LANES_GT(sizeof(TFromD<D>), sizeof(TFromV<V>))> |
| HWY_API VFromD<D> PromoteTo(D di, V v) { |
| const RebindToUnsigned<decltype(di)> du; |
| return BitCast(di, PromoteTo(du, v)); |
| } |
| |
| // ------------------------------ PromoteTo F |
| |
| // Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. |
| #ifdef HWY_NATIVE_F16C |
| #undef HWY_NATIVE_F16C |
| #else |
| #define HWY_NATIVE_F16C |
| #endif |
| |
| // Unlike Highway's ZipLower, this returns the same type. |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLowerSame, zip1) |
| } // namespace detail |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat32_t PromoteTo(Simd<float32_t, N, kPow2> /* d */, |
| const svfloat16_t v) { |
| // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so |
| // first replicate each lane once. |
| const svfloat16_t vv = detail::ZipLowerSame(v, v); |
| return svcvt_f32_f16_x(detail::PTrue(Simd<float16_t, N, kPow2>()), vv); |
| } |
| |
| #ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 |
| #undef HWY_NATIVE_PROMOTE_F16_TO_F64 |
| #else |
| #define HWY_NATIVE_PROMOTE_F16_TO_F64 |
| #endif |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */, |
| const svfloat16_t v) { |
  // svcvt* expects inputs in every fourth lane (those overlapping the f64
  // lanes), whereas Highway wants the lower lanes, so replicate each lane
  // twice, for a total of 4x.
| const svfloat16_t vv = detail::ZipLowerSame(v, v); |
| return svcvt_f64_f16_x(detail::PTrue(Simd<float16_t, N, kPow2>()), |
| detail::ZipLowerSame(vv, vv)); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */, |
| const svfloat32_t v) { |
| const svfloat32_t vv = detail::ZipLowerSame(v, v); |
| return svcvt_f64_f32_x(detail::PTrue(Simd<float32_t, N, kPow2>()), vv); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */, |
| const svint32_t v) { |
| const svint32_t vv = detail::ZipLowerSame(v, v); |
| return svcvt_f64_s32_x(detail::PTrue(Simd<int32_t, N, kPow2>()), vv); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */, |
| const svuint32_t v) { |
| const svuint32_t vv = detail::ZipLowerSame(v, v); |
| return svcvt_f64_u32_x(detail::PTrue(Simd<uint32_t, N, kPow2>()), vv); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svint64_t PromoteTo(Simd<int64_t, N, kPow2> /* d */, |
| const svfloat32_t v) { |
| const svfloat32_t vv = detail::ZipLowerSame(v, v); |
| return svcvt_s64_f32_x(detail::PTrue(Simd<float, N, kPow2>()), vv); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint64_t PromoteTo(Simd<uint64_t, N, kPow2> /* d */, |
| const svfloat32_t v) { |
| const svfloat32_t vv = detail::ZipLowerSame(v, v); |
| return svcvt_u64_f32_x(detail::PTrue(Simd<float, N, kPow2>()), vv); |
| } |
| |
| // ------------------------------ PromoteUpperTo |
| |
| namespace detail { |
| HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) |
| #undef HWY_SVE_PROMOTE_TO |
| } // namespace detail |
| |
| #ifdef HWY_NATIVE_PROMOTE_UPPER_TO |
| #undef HWY_NATIVE_PROMOTE_UPPER_TO |
| #else |
| #define HWY_NATIVE_PROMOTE_UPPER_TO |
| #endif |
| |
| // Unsigned->Unsigned or Signed->Signed |
| template <class D, class V, typename TD = TFromD<D>, typename TV = TFromV<V>, |
| hwy::EnableIf<IsInteger<TD>() && IsInteger<TV>() && |
| (IsSigned<TD>() == IsSigned<TV>())>* = nullptr> |
| HWY_API VFromD<D> PromoteUpperTo(D d, V v) { |
| if (detail::IsFull(d)) { |
| return detail::PromoteUpperTo(d, v); |
| } |
| const Rebind<TFromV<V>, decltype(d)> dh; |
| return PromoteTo(d, UpperHalf(dh, v)); |
| } |
| |
| // Differing signs or either is float |
| template <class D, class V, typename TD = TFromD<D>, typename TV = TFromV<V>, |
| hwy::EnableIf<!IsInteger<TD>() || !IsInteger<TV>() || |
| (IsSigned<TD>() != IsSigned<TV>())>* = nullptr> |
| HWY_API VFromD<D> PromoteUpperTo(D d, V v) { |
| // Lanes(d) may differ from Lanes(DFromV<V>()). Use the lane type from V |
| // because it cannot be deduced from D (could be either bf16 or f16). |
| const Rebind<TFromV<V>, decltype(d)> dh; |
| return PromoteTo(d, UpperHalf(dh, v)); |
| } |
| |
| // ------------------------------ DemoteTo U |
| |
| namespace detail { |
| |
| // Saturates unsigned vectors to half/quarter-width TN. |
| template <typename TN, class VU> |
| VU SaturateU(VU v) { |
| return detail::MinN(v, static_cast<TFromV<VU>>(LimitsMax<TN>())); |
| } |
| |
// Saturates signed vectors to half/quarter-width TN.
| template <typename TN, class VI> |
| VI SaturateI(VI v) { |
| return detail::MinN(detail::MaxN(v, LimitsMin<TN>()), LimitsMax<TN>()); |
| } |
| |
| } // namespace detail |
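
// Example: SaturateI<int8_t> clamps an i16 vector to [-128, 127], so the
// values 300 and -300 become 127 and -128 before the width is halved.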
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint16_t v) { |
| #if HWY_SVE_HAVE_2 |
| const svuint8_t vn = BitCast(dn, svqxtunb_s16(v)); |
| #else |
| const DFromV<decltype(v)> di; |
| const RebindToUnsigned<decltype(di)> du; |
| using TN = TFromD<decltype(dn)>; |
| // First clamp negative numbers to zero and cast to unsigned. |
| const svuint16_t clamped = BitCast(du, detail::MaxN(v, 0)); |
| // Saturate to unsigned-max and halve the width. |
| const svuint8_t vn = BitCast(dn, detail::SaturateU<TN>(clamped)); |
| #endif |
| return svuzp1_u8(vn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svint32_t v) { |
| #if HWY_SVE_HAVE_2 |
| const svuint16_t vn = BitCast(dn, svqxtunb_s32(v)); |
| #else |
| const DFromV<decltype(v)> di; |
| const RebindToUnsigned<decltype(di)> du; |
| using TN = TFromD<decltype(dn)>; |
| // First clamp negative numbers to zero and cast to unsigned. |
| const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); |
| // Saturate to unsigned-max and halve the width. |
| const svuint16_t vn = BitCast(dn, detail::SaturateU<TN>(clamped)); |
| #endif |
| return svuzp1_u16(vn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint32_t v) { |
| const DFromV<decltype(v)> di; |
| const RebindToUnsigned<decltype(di)> du; |
| const RepartitionToNarrow<decltype(du)> d2; |
| #if HWY_SVE_HAVE_2 |
| const svuint16_t cast16 = BitCast(d2, svqxtnb_u16(svqxtunb_s32(v))); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| // First clamp negative numbers to zero and cast to unsigned. |
| const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); |
| // Saturate to unsigned-max and quarter the width. |
| const svuint16_t cast16 = BitCast(d2, detail::SaturateU<TN>(clamped)); |
| #endif |
| const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16)); |
| return svuzp1_u8(x2, x2); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svuint16_t v) { |
| #if HWY_SVE_HAVE_2 |
| const svuint8_t vn = BitCast(dn, svqxtnb_u16(v)); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| const svuint8_t vn = BitCast(dn, detail::SaturateU<TN>(v)); |
| #endif |
| return svuzp1_u8(vn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svuint32_t v) { |
| #if HWY_SVE_HAVE_2 |
| const svuint16_t vn = BitCast(dn, svqxtnb_u32(v)); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| const svuint16_t vn = BitCast(dn, detail::SaturateU<TN>(v)); |
| #endif |
| return svuzp1_u16(vn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svuint32_t v) { |
| using TN = TFromD<decltype(dn)>; |
| return U8FromU32(detail::SaturateU<TN>(v)); |
| } |
| |
| // ------------------------------ Truncations |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */, |
| const svuint64_t v) { |
| const DFromV<svuint8_t> d; |
| const svuint8_t v1 = BitCast(d, v); |
| const svuint8_t v2 = svuzp1_u8(v1, v1); |
| const svuint8_t v3 = svuzp1_u8(v2, v2); |
| return svuzp1_u8(v3, v3); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint16_t TruncateTo(Simd<uint16_t, N, kPow2> /* tag */, |
| const svuint64_t v) { |
| const DFromV<svuint16_t> d; |
| const svuint16_t v1 = BitCast(d, v); |
| const svuint16_t v2 = svuzp1_u16(v1, v1); |
| return svuzp1_u16(v2, v2); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint32_t TruncateTo(Simd<uint32_t, N, kPow2> /* tag */, |
| const svuint64_t v) { |
| const DFromV<svuint32_t> d; |
| const svuint32_t v1 = BitCast(d, v); |
| return svuzp1_u32(v1, v1); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */, |
| const svuint32_t v) { |
| const DFromV<svuint8_t> d; |
| const svuint8_t v1 = BitCast(d, v); |
| const svuint8_t v2 = svuzp1_u8(v1, v1); |
| return svuzp1_u8(v2, v2); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint16_t TruncateTo(Simd<uint16_t, N, kPow2> /* tag */, |
| const svuint32_t v) { |
| const DFromV<svuint16_t> d; |
| const svuint16_t v1 = BitCast(d, v); |
| return svuzp1_u16(v1, v1); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */, |
| const svuint16_t v) { |
| const DFromV<svuint8_t> d; |
| const svuint8_t v1 = BitCast(d, v); |
| return svuzp1_u8(v1, v1); |
| } |
| |
| // ------------------------------ DemoteTo I |
| |
| template <size_t N, int kPow2> |
| HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint16_t v) { |
| #if HWY_SVE_HAVE_2 |
| const svint8_t vn = BitCast(dn, svqxtnb_s16(v)); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| const svint8_t vn = BitCast(dn, detail::SaturateI<TN>(v)); |
| #endif |
| return svuzp1_s8(vn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svint16_t DemoteTo(Simd<int16_t, N, kPow2> dn, const svint32_t v) { |
| #if HWY_SVE_HAVE_2 |
| const svint16_t vn = BitCast(dn, svqxtnb_s32(v)); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| const svint16_t vn = BitCast(dn, detail::SaturateI<TN>(v)); |
| #endif |
| return svuzp1_s16(vn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint32_t v) { |
| const RepartitionToWide<decltype(dn)> d2; |
| #if HWY_SVE_HAVE_2 |
| const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v))); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| const svint16_t cast16 = BitCast(d2, detail::SaturateI<TN>(v)); |
| #endif |
| const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16)); |
| return BitCast(dn, svuzp1_s8(v2, v2)); |
| } |
| |
| // ------------------------------ I64/U64 DemoteTo |
| |
| template <size_t N, int kPow2> |
| HWY_API svint32_t DemoteTo(Simd<int32_t, N, kPow2> dn, const svint64_t v) { |
| const Rebind<uint64_t, decltype(dn)> du64; |
| const RebindToUnsigned<decltype(dn)> dn_u; |
| #if HWY_SVE_HAVE_2 |
| const svuint64_t vn = BitCast(du64, svqxtnb_s64(v)); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| const svuint64_t vn = BitCast(du64, detail::SaturateI<TN>(v)); |
| #endif |
| return BitCast(dn, TruncateTo(dn_u, vn)); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svint16_t DemoteTo(Simd<int16_t, N, kPow2> dn, const svint64_t v) { |
| const Rebind<uint64_t, decltype(dn)> du64; |
| const RebindToUnsigned<decltype(dn)> dn_u; |
| #if HWY_SVE_HAVE_2 |
| const svuint64_t vn = BitCast(du64, svqxtnb_s32(svqxtnb_s64(v))); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| const svuint64_t vn = BitCast(du64, detail::SaturateI<TN>(v)); |
| #endif |
| return BitCast(dn, TruncateTo(dn_u, vn)); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint64_t v) { |
| const Rebind<uint64_t, decltype(dn)> du64; |
| const RebindToUnsigned<decltype(dn)> dn_u; |
| using TN = TFromD<decltype(dn)>; |
| const svuint64_t vn = BitCast(du64, detail::SaturateI<TN>(v)); |
| return BitCast(dn, TruncateTo(dn_u, vn)); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint32_t DemoteTo(Simd<uint32_t, N, kPow2> dn, const svint64_t v) { |
| const Rebind<uint64_t, decltype(dn)> du64; |
| #if HWY_SVE_HAVE_2 |
| const svuint64_t vn = BitCast(du64, svqxtunb_s64(v)); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| // First clamp negative numbers to zero and cast to unsigned. |
| const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); |
| // Saturate to unsigned-max |
| const svuint64_t vn = detail::SaturateU<TN>(clamped); |
| #endif |
| return TruncateTo(dn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svint64_t v) { |
| const Rebind<uint64_t, decltype(dn)> du64; |
| #if HWY_SVE_HAVE_2 |
| const svuint64_t vn = BitCast(du64, svqxtnb_u32(svqxtunb_s64(v))); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| // First clamp negative numbers to zero and cast to unsigned. |
| const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); |
| // Saturate to unsigned-max |
| const svuint64_t vn = detail::SaturateU<TN>(clamped); |
| #endif |
| return TruncateTo(dn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint64_t v) { |
| const Rebind<uint64_t, decltype(dn)> du64; |
| using TN = TFromD<decltype(dn)>; |
| // First clamp negative numbers to zero and cast to unsigned. |
| const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); |
| // Saturate to unsigned-max |
| const svuint64_t vn = detail::SaturateU<TN>(clamped); |
| return TruncateTo(dn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint32_t DemoteTo(Simd<uint32_t, N, kPow2> dn, const svuint64_t v) { |
| const Rebind<uint64_t, decltype(dn)> du64; |
| #if HWY_SVE_HAVE_2 |
| const svuint64_t vn = BitCast(du64, svqxtnb_u64(v)); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| const svuint64_t vn = BitCast(du64, detail::SaturateU<TN>(v)); |
| #endif |
| return TruncateTo(dn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svuint64_t v) { |
| const Rebind<uint64_t, decltype(dn)> du64; |
| #if HWY_SVE_HAVE_2 |
| const svuint64_t vn = BitCast(du64, svqxtnb_u32(svqxtnb_u64(v))); |
| #else |
| using TN = TFromD<decltype(dn)>; |
| const svuint64_t vn = BitCast(du64, detail::SaturateU<TN>(v)); |
| #endif |
| return TruncateTo(dn, vn); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svuint64_t v) { |
| const Rebind<uint64_t, decltype(dn)> du64; |
| using TN = TFromD<decltype(dn)>; |
| const svuint64_t vn = BitCast(du64, detail::SaturateU<TN>(v)); |
| return TruncateTo(dn, vn); |
| } |
| |
| // ------------------------------ Unsigned to signed demotions |
| |
// Disable the default unsigned to signed DemoteTo/ReorderDemote2To
// implementations in generic_ops-inl.h, because the SVE/SVE2 targets provide
// their own implementations of these ops.
| |
// NOTE: hwy::EnableIf<!hwy::IsSame<V, V>()>* = nullptr is used instead of
// hwy::EnableIf<false>* = nullptr to avoid compiler errors: although
// !hwy::IsSame<V, V>() is always false, it depends on the V template argument
// and thus triggers SFINAE instead of a hard error.
| #undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V |
| #define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ |
| hwy::EnableIf<!hwy::IsSame<V, V>()>* = nullptr |
| |
| template <class D, class V, HWY_IF_SIGNED_D(D), HWY_IF_UNSIGNED_V(V), |
| HWY_IF_T_SIZE_LE_D(D, sizeof(TFromV<V>) - 1)> |
| HWY_API VFromD<D> DemoteTo(D dn, V v) { |
| const RebindToUnsigned<D> dn_u; |
| return BitCast(dn, TruncateTo(dn_u, detail::SaturateU<TFromD<D>>(v))); |
| } |
| |
| // ------------------------------ PromoteEvenTo/PromoteOddTo |
| |
// Signed to signed PromoteEvenTo: 1 instruction instead of 2 in
// generic_ops-inl.h. Might as well also enable unsigned to unsigned, though
// it is just an And.
| namespace detail { |
| HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, extb) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, exth) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, extw) |
| } // namespace detail |
| |
| #include "third_party/highway/hwy/ops/inside-inl.h" |
| |
| // ------------------------------ DemoteTo F |
| |
| // We already toggled HWY_NATIVE_F16C above. |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat16_t DemoteTo(Simd<float16_t, N, kPow2> d, const svfloat32_t v) { |
| const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v); |
| return detail::ConcatEvenFull(in_even, |
| in_even); // lower half |
| } |
| |
| #ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 |
| #undef HWY_NATIVE_DEMOTE_F64_TO_F16 |
| #else |
| #define HWY_NATIVE_DEMOTE_F64_TO_F16 |
| #endif |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat16_t DemoteTo(Simd<float16_t, N, kPow2> d, const svfloat64_t v) { |
| const svfloat16_t in_lo16 = svcvt_f16_f64_x(detail::PTrue(d), v); |
| const svfloat16_t in_even = detail::ConcatEvenFull(in_lo16, in_lo16); |
| return detail::ConcatEvenFull(in_even, |
| in_even); // lower half |
| } |
| |
| #ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 |
| #undef HWY_NATIVE_DEMOTE_F32_TO_BF16 |
| #else |
| #define HWY_NATIVE_DEMOTE_F32_TO_BF16 |
| #endif |
| |
| #if !HWY_SVE_HAVE_F32_TO_BF16C |
| namespace detail { |
| |
// Rounds an F32 value to the nearest BF16 value, returning the rounded F32
// value bitcast to a U32.

// RoundF32ForDemoteToBF16 also converts NaN values to QNaN values, so that
// the rounding increment cannot turn a NaN F32 value into an infinity.
| HWY_INLINE svuint32_t RoundF32ForDemoteToBF16(svfloat32_t v) { |
| const DFromV<decltype(v)> df32; |
| const RebindToUnsigned<decltype(df32)> du32; |
| |
| const auto is_non_nan = Eq(v, v); |
| const auto bits32 = BitCast(du32, v); |
| |
| const auto round_incr = |
| detail::AddN(detail::AndN(ShiftRight<16>(bits32), 1u), 0x7FFFu); |
| return MaskedAddOr(detail::OrN(bits32, 0x00400000u), is_non_nan, bits32, |
| round_incr); |
| } |
| |
| } // namespace detail |
| #endif // !HWY_SVE_HAVE_F32_TO_BF16C |
| |
| template <size_t N, int kPow2> |
| HWY_API VBF16 DemoteTo(Simd<bfloat16_t, N, kPow2> dbf16, svfloat32_t v) { |
| #if HWY_SVE_HAVE_F32_TO_BF16C |
| const VBF16 in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), v); |
| return detail::ConcatEvenFull(in_even, in_even); |
| #else |
| const svuint16_t in_odd = |
| BitCast(ScalableTag<uint16_t>(), detail::RoundF32ForDemoteToBF16(v)); |
| return BitCast(dbf16, detail::ConcatOddFull(in_odd, in_odd)); // lower half |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat32_t DemoteTo(Simd<float32_t, N, kPow2> d, const svfloat64_t v) { |
| const svfloat32_t in_even = svcvt_f32_f64_x(detail::PTrue(d), v); |
| return detail::ConcatEvenFull(in_even, |
| in_even); // lower half |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svint32_t DemoteTo(Simd<int32_t, N, kPow2> d, const svfloat64_t v) { |
| const svint32_t in_even = svcvt_s32_f64_x(detail::PTrue(d), v); |
| return detail::ConcatEvenFull(in_even, |
| in_even); // lower half |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint32_t DemoteTo(Simd<uint32_t, N, kPow2> d, const svfloat64_t v) { |
| const svuint32_t in_even = svcvt_u32_f64_x(detail::PTrue(d), v); |
| return detail::ConcatEvenFull(in_even, |
| in_even); // lower half |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat32_t DemoteTo(Simd<float, N, kPow2> d, const svint64_t v) { |
| const svfloat32_t in_even = svcvt_f32_s64_x(detail::PTrue(d), v); |
| return detail::ConcatEvenFull(in_even, |
| in_even); // lower half |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat32_t DemoteTo(Simd<float, N, kPow2> d, const svuint64_t v) { |
| const svfloat32_t in_even = svcvt_f32_u64_x(detail::PTrue(d), v); |
| return detail::ConcatEvenFull(in_even, |
| in_even); // lower half |
| } |
| |
| // ------------------------------ ConvertTo F |
| |
| #define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| /* Float from signed */ \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ |
| } \ |
| /* Float from unsigned */ \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(uint, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS##_u##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ |
| } \ |
| /* Signed from float, rounding toward zero */ \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(int, BITS) \ |
| NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ |
| } \ |
| /* Unsigned from float, rounding toward zero */ \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(uint, BITS) \ |
| NAME(HWY_SVE_D(uint, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_u##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ |
| } |
| |
| HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt) |
| #undef HWY_SVE_CONVERT |
| |
| // ------------------------------ MaskedConvertTo F |
| |
| #define HWY_SVE_MASKED_CONVERT_TO_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| /* Float from signed */ \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ |
| HWY_SVE_V(int, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS##_s##BITS##_z(m, v); \ |
| } \ |
| /* Float from unsigned */ \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ |
| HWY_SVE_V(uint, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS##_u##BITS##_z(m, v); \ |
| } \ |
| /* Signed from float, rounding toward zero */ \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(int, BITS) \ |
| NAME(svbool_t m, HWY_SVE_D(int, BITS, N, kPow2) /* d */, \ |
| HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_s##BITS##_##CHAR##BITS##_z(m, v); \ |
| } \ |
| /* Unsigned from float, rounding toward zero */ \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(uint, BITS) \ |
| NAME(svbool_t m, HWY_SVE_D(uint, BITS, N, kPow2) /* d */, \ |
| HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_u##BITS##_##CHAR##BITS##_z(m, v); \ |
| } |
| |
| HWY_SVE_FOREACH_F(HWY_SVE_MASKED_CONVERT_TO_OR_ZERO, MaskedConvertTo, cvt) |
| #undef HWY_SVE_MASKED_CONVERT_TO_OR_ZERO |
| |
| // ------------------------------ NearestInt (Round, ConvertTo) |
| template <class VF, class DI = RebindToSigned<DFromV<VF>>> |
| HWY_API VFromD<DI> NearestInt(VF v) { |
| // No single instruction, round then truncate. |
| return ConvertTo(DI(), Round(v)); |
| } |
| |
| template <class DI32, HWY_IF_I32_D(DI32)> |
| HWY_API VFromD<DI32> DemoteToNearestInt(DI32 di32, |
| VFromD<Rebind<double, DI32>> v) { |
| // No single instruction, round then demote. |
| return DemoteTo(di32, Round(v)); |
| } |
| |
| // ------------------------------ Iota (AddN, ConvertTo) |
| |
| #define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2, typename T2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, T2 first) { \ |
| return sv##OP##_##CHAR##BITS( \ |
| ConvertScalarTo<HWY_SVE_T(BASE, BITS)>(first), 1); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index) |
| #undef HWY_SVE_IOTA |
| |
| template <class D, typename T = TFromD<D>, typename T2, HWY_IF_FLOAT(T)> |
| HWY_API VFromD<D> Iota(const D d, T2 first) { |
| const RebindToSigned<D> di; |
| const T first_f = ConvertScalarTo<T>(first); |
| const VFromD<D> iota_f = ConvertTo(d, Iota(di, 0)); |
| return detail::AddN(iota_f, first_f); |
| } |
| |
| // ================================================== LANE ACCESS |
| |
| // ------------------------------ ExtractLane (GetLaneM, FirstN) |
| template <class V> |
| HWY_API TFromV<V> ExtractLane(V v, size_t i) { |
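  // FirstN(d, i) sets the first i lanes; GetLaneM then returns the lane just
  // past the last active one, i.e. lane i.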
| return detail::GetLaneM(v, FirstN(DFromV<V>(), i)); |
| } |
| |
| // ------------------------------ InsertLane (IfThenElse, EqN) |
| template <class V, typename T> |
| HWY_API V InsertLane(const V v, size_t i, T t) { |
| static_assert(sizeof(TFromV<V>) == sizeof(T), "Lane size mismatch"); |
| const DFromV<V> d; |
| const RebindToSigned<decltype(d)> di; |
| using TI = TFromD<decltype(di)>; |
| const svbool_t is_i = detail::EqN(Iota(di, 0), static_cast<TI>(i)); |
| // The actual type may be int16_t for special floats; copy, not cast. |
| TFromV<V> t_bits; |
| hwy::CopySameSize(&t, &t_bits); |
| return IfThenElse(RebindMask(d, is_i), Set(d, t_bits), v); |
| } |
| |
| // ------------------------------ GetExponent |
| |
| #if HWY_SVE_HAVE_2 || HWY_IDE |
| #ifdef HWY_NATIVE_GET_EXPONENT |
| #undef HWY_NATIVE_GET_EXPONENT |
| #else |
| #define HWY_NATIVE_GET_EXPONENT |
| #endif |
| |
| namespace detail { |
| #define HWY_SVE_GET_EXP(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(int, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ |
| } |
| HWY_SVE_FOREACH_F(HWY_SVE_GET_EXP, GetExponent, logb) |
| #undef HWY_SVE_GET_EXP |
| } // namespace detail |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V GetExponent(V v) { |
| const DFromV<V> d; |
| const RebindToSigned<decltype(d)> di; |
| const VFromD<decltype(di)> exponent_int = detail::GetExponent(v); |
| // convert integer to original type |
| return ConvertTo(d, exponent_int); |
| } |
| #endif // HWY_SVE_HAVE_2 |
| |
| // ------------------------------ InterleaveLower |
| |
| template <class D, class V> |
| HWY_API V InterleaveLower(D d, const V a, const V b) { |
| static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch"); |
| #if HWY_TARGET == HWY_SVE2_128 |
| (void)d; |
| return detail::ZipLowerSame(a, b); |
| #else |
| // Move lower halves of blocks to lower half of vector. |
| const Repartition<uint64_t, decltype(d)> d64; |
| const auto a64 = BitCast(d64, a); |
| const auto b64 = BitCast(d64, b); |
| const auto a_blocks = detail::ConcatEvenFull(a64, a64); // lower half |
| const auto b_blocks = detail::ConcatEvenFull(b64, b64); |
| return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); |
| #endif |
| } |
| |
| template <class V> |
| HWY_API V InterleaveLower(const V a, const V b) { |
| return InterleaveLower(DFromV<V>(), a, b); |
| } |
| |
| // ------------------------------ InterleaveUpper |
| |
// Only use zip2 if vectors are a power of two; otherwise getting the actual
// "upper half" requires MaskUpperHalf.
| namespace detail { |
| // Unlike Highway's ZipUpper, this returns the same type. |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipUpperSame, zip2) |
| } // namespace detail |
| |
| // Full vector: guaranteed to have at least one block |
| template <class D, class V = VFromD<D>, |
| hwy::EnableIf<detail::IsFull(D())>* = nullptr> |
| HWY_API V InterleaveUpper(D d, const V a, const V b) { |
| #if HWY_TARGET == HWY_SVE2_128 |
| (void)d; |
| return detail::ZipUpperSame(a, b); |
| #else |
| // Move upper halves of blocks to lower half of vector. |
| const Repartition<uint64_t, decltype(d)> d64; |
| const auto a64 = BitCast(d64, a); |
| const auto b64 = BitCast(d64, b); |
| const auto a_blocks = detail::ConcatOddFull(a64, a64); // lower half |
| const auto b_blocks = detail::ConcatOddFull(b64, b64); |
| return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); |
| #endif |
| } |
| |
| // Capped/fraction: need runtime check |
| template <class D, class V = VFromD<D>, |
| hwy::EnableIf<!detail::IsFull(D())>* = nullptr> |
| HWY_API V InterleaveUpper(D d, const V a, const V b) { |
| // Less than one block: treat as capped |
| if (Lanes(d) * sizeof(TFromD<D>) < 16) { |
| const Half<decltype(d)> d2; |
| return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); |
| } |
| return InterleaveUpper(DFromV<V>(), a, b); |
| } |
| |
| // ------------------------------ InterleaveWholeLower |
| #ifdef HWY_NATIVE_INTERLEAVE_WHOLE |
| #undef HWY_NATIVE_INTERLEAVE_WHOLE |
| #else |
| #define HWY_NATIVE_INTERLEAVE_WHOLE |
| #endif |
| |
| template <class D> |
| HWY_API VFromD<D> InterleaveWholeLower(D /*d*/, VFromD<D> a, VFromD<D> b) { |
| return detail::ZipLowerSame(a, b); |
| } |
| |
| // ------------------------------ InterleaveWholeUpper |
| |
| template <class D> |
| HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) { |
| if (HWY_SVE_IS_POW2 && detail::IsFull(d)) { |
| return detail::ZipUpperSame(a, b); |
| } |
| |
| const Half<decltype(d)> d2; |
| return InterleaveWholeLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); |
| } |
| |
| // ------------------------------ Per4LaneBlockShuffle |
| |
| namespace detail { |
| |
| template <size_t kLaneSize, size_t kVectSize, class V, |
| HWY_IF_NOT_T_SIZE_V(V, 8)> |
| HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x88> /*idx_3210_tag*/, |
| hwy::SizeTag<kLaneSize> /*lane_size_tag*/, |
| hwy::SizeTag<kVectSize> /*vect_size_tag*/, |
| V v) { |
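  // Tag 0x88 encodes per-4-lane indices {0, 2, 0, 2} (lanes 0..3), i.e.
  // broadcast the even lanes of each group of four.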
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| const RepartitionToWide<decltype(du)> dw; |
| |
| const auto evens = BitCast(dw, ConcatEvenFull(v, v)); |
| return BitCast(d, ZipLowerSame(evens, evens)); |
| } |
| |
| template <size_t kLaneSize, size_t kVectSize, class V, |
| HWY_IF_NOT_T_SIZE_V(V, 8)> |
| HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xDD> /*idx_3210_tag*/, |
| hwy::SizeTag<kLaneSize> /*lane_size_tag*/, |
| hwy::SizeTag<kVectSize> /*vect_size_tag*/, |
| V v) { |
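  // Tag 0xDD encodes per-4-lane indices {1, 3, 1, 3} (lanes 0..3), i.e.
  // broadcast the odd lanes of each group of four.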
| const DFromV<decltype(v)> d; |
| const RebindToUnsigned<decltype(d)> du; |
| const RepartitionToWide<decltype(du)> dw; |
| |
| const auto odds = BitCast(dw, ConcatOddFull(v, v)); |
| return BitCast(d, ZipLowerSame(odds, odds)); |
| } |
| |
| } // namespace detail |
| |
| // ================================================== COMBINE |
| |
| namespace detail { |
| |
| #if (HWY_TARGET == HWY_SVE_256 && HWY_HAVE_CONSTEXPR_LANES) || HWY_IDE |
| template <class D, HWY_IF_T_SIZE_D(D, 1)> |
| svbool_t MaskLowerHalf(D d) { |
| switch (MaxLanes(d)) { |
| case 32: |
| return svptrue_pat_b8(SV_VL16); |
| case 16: |
| return svptrue_pat_b8(SV_VL8); |
| case 8: |
| return svptrue_pat_b8(SV_VL4); |
| case 4: |
| return svptrue_pat_b8(SV_VL2); |
| default: |
| return svptrue_pat_b8(SV_VL1); |
| } |
| } |
| template <class D, HWY_IF_T_SIZE_D(D, 2)> |
| svbool_t MaskLowerHalf(D d) { |
| switch (MaxLanes(d)) { |
| case 16: |
| return svptrue_pat_b16(SV_VL8); |
| case 8: |
| return svptrue_pat_b16(SV_VL4); |
| case 4: |
| return svptrue_pat_b16(SV_VL2); |
| default: |
| return svptrue_pat_b16(SV_VL1); |
| } |
| } |
| template <class D, HWY_IF_T_SIZE_D(D, 4)> |
| svbool_t MaskLowerHalf(D d) { |
| switch (MaxLanes(d)) { |
| case 8: |
| return svptrue_pat_b32(SV_VL4); |
| case 4: |
| return svptrue_pat_b32(SV_VL2); |
| default: |
| return svptrue_pat_b32(SV_VL1); |
| } |
| } |
| template <class D, HWY_IF_T_SIZE_D(D, 8)> |
| svbool_t MaskLowerHalf(D d) { |
| switch (MaxLanes(d)) { |
| case 4: |
| return svptrue_pat_b64(SV_VL2); |
| default: |
| return svptrue_pat_b64(SV_VL1); |
| } |
| } |
| #endif |
| #if (HWY_TARGET == HWY_SVE2_128 && HWY_HAVE_CONSTEXPR_LANES) || HWY_IDE |
| template <class D, HWY_IF_T_SIZE_D(D, 1)> |
| svbool_t MaskLowerHalf(D d) { |
| switch (MaxLanes(d)) { |
| case 16: |
| return svptrue_pat_b8(SV_VL8); |
| case 8: |
| return svptrue_pat_b8(SV_VL4); |
| case 4: |
| return svptrue_pat_b8(SV_VL2); |
| case 2: |
| case 1: |
| default: |
| return svptrue_pat_b8(SV_VL1); |
| } |
| } |
| template <class D, HWY_IF_T_SIZE_D(D, 2)> |
| svbool_t MaskLowerHalf(D d) { |
| switch (MaxLanes(d)) { |
| case 8: |
| return svptrue_pat_b16(SV_VL4); |
| case 4: |
| return svptrue_pat_b16(SV_VL2); |
| case 2: |
| case 1: |
| default: |
| return svptrue_pat_b16(SV_VL1); |
| } |
| } |
| template <class D, HWY_IF_T_SIZE_D(D, 4)> |
| svbool_t MaskLowerHalf(D d) { |
| return svptrue_pat_b32(MaxLanes(d) == 4 ? SV_VL2 : SV_VL1); |
| } |
| template <class D, HWY_IF_T_SIZE_D(D, 8)> |
| svbool_t MaskLowerHalf(D /*d*/) { |
| return svptrue_pat_b64(SV_VL1); |
| } |
#endif  // HWY_TARGET == HWY_SVE2_128 && HWY_HAVE_CONSTEXPR_LANES
| #if (HWY_TARGET != HWY_SVE_256 && HWY_TARGET != HWY_SVE2_128) || \ |
| !HWY_HAVE_CONSTEXPR_LANES |
| template <class D> |
| svbool_t MaskLowerHalf(D d) { |
| return FirstN(d, Lanes(d) / 2); |
| } |
| #endif |
| |
| template <class D> |
| svbool_t MaskUpperHalf(D d) { |
| // TODO(janwas): WHILEGE on SVE2 |
| if (HWY_SVE_IS_POW2 && IsFull(d)) { |
| return Not(MaskLowerHalf(d)); |
| } |
| |
| // For Splice to work as intended, make sure bits above Lanes(d) are zero. |
| return AndNot(MaskLowerHalf(d), detail::MakeMask(d)); |
| } |
| |
// Right-shift a vector pair by a constexpr lane count; can be used to slide
// down (kIndex = N) or up (kIndex = Lanes() - N).
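// Example: with four lanes, Ext<2>(hi, lo) returns {lo[2], lo[3], hi[0],
// hi[1]}.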
| #define HWY_SVE_EXT(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t kIndex> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ |
| return sv##OP##_##CHAR##BITS(lo, hi, kIndex); \ |
| } |
| HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext) |
| #undef HWY_SVE_EXT |
| |
| } // namespace detail |
| |
| // ------------------------------ ConcatUpperLower |
| template <class D, class V> |
| HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) { |
| return IfThenElse(detail::MaskLowerHalf(d), lo, hi); |
| } |
| |
| // ------------------------------ ConcatLowerLower |
| template <class D, class V> |
| HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) { |
| if (detail::IsFull(d)) { |
| #if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 |
| return detail::ConcatEvenBlocks(hi, lo); |
| #endif |
| #if HWY_TARGET == HWY_SVE2_128 |
| const Repartition<uint64_t, D> du64; |
| const auto lo64 = BitCast(du64, lo); |
| return BitCast(d, InterleaveLower(du64, lo64, BitCast(du64, hi))); |
| #endif |
| } |
| return detail::Splice(hi, lo, detail::MaskLowerHalf(d)); |
| } |
| |
| // ------------------------------ ConcatLowerUpper |
| template <class D, class V> |
| HWY_API V ConcatLowerUpper(const D d, const V hi, const V lo) { |
| #if HWY_HAVE_CONSTEXPR_LANES |
| if (detail::IsFull(d)) { |
| return detail::Ext<Lanes(d) / 2>(hi, lo); |
| } |
| #endif |
| return detail::Splice(hi, lo, detail::MaskUpperHalf(d)); |
| } |
| |
| // ------------------------------ ConcatUpperUpper |
| template <class D, class V> |
| HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) { |
| if (detail::IsFull(d)) { |
| #if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 |
| return detail::ConcatOddBlocks(hi, lo); |
| #endif |
| #if HWY_TARGET == HWY_SVE2_128 |
| const Repartition<uint64_t, D> du64; |
| const auto lo64 = BitCast(du64, lo); |
| return BitCast(d, InterleaveUpper(du64, lo64, BitCast(du64, hi))); |
| #endif |
| } |
| const svbool_t mask_upper = detail::MaskUpperHalf(d); |
| const V lo_upper = detail::Splice(lo, lo, mask_upper); |
| return IfThenElse(mask_upper, hi, lo_upper); |
| } |
| |
| // ------------------------------ Combine |
| template <class D, class V2> |
| HWY_API VFromD<D> Combine(const D d, const V2 hi, const V2 lo) { |
| return ConcatLowerLower(d, hi, lo); |
| } |
| |
| // ------------------------------ ZeroExtendVector |
| template <class D, class V> |
| HWY_API V ZeroExtendVector(const D d, const V lo) { |
| return Combine(d, Zero(Half<D>()), lo); |
| } |
| |
| // ------------------------------ Lower/UpperHalf |
| |
| template <class D2, class V> |
| HWY_API V LowerHalf(D2 /* tag */, const V v) { |
| return v; |
| } |
| |
| template <class V> |
| HWY_API V LowerHalf(const V v) { |
| return v; |
| } |
| |
| template <class DH, class V> |
| HWY_API V UpperHalf(const DH dh, const V v) { |
| const Twice<decltype(dh)> d; |
| // Cast so that we support bfloat16_t. |
| const RebindToUnsigned<decltype(d)> du; |
| const VFromD<decltype(du)> vu = BitCast(du, v); |
| #if HWY_HAVE_CONSTEXPR_LANES |
| return BitCast(d, detail::Ext<Lanes(dh)>(vu, vu)); |
| #else |
| const MFromD<decltype(du)> mask = detail::MaskUpperHalf(du); |
| return BitCast(d, detail::Splice(vu, vu, mask)); |
| #endif |
| } |
| |
| // ================================================== SWIZZLE |
| |
| // ------------------------------ DupEven |
| |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1) |
| } // namespace detail |
| |
| template <class V> |
| HWY_API V DupEven(const V v) { |
| return detail::InterleaveEven(v, v); |
| } |
| |
| // ------------------------------ DupOdd |
| |
| namespace detail { |
| HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2) |
| } // namespace detail |
| |
| template <class V> |
| HWY_API V DupOdd(const V v) { |
| return detail::InterleaveOdd(v, v); |
| } |
| |
| // ------------------------------ OddEven |
| |
| #if HWY_SVE_HAVE_2 |
| |
| #define HWY_SVE_ODD_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) odd, HWY_SVE_V(BASE, BITS) even) { \ |
| return sv##OP##_##CHAR##BITS(even, odd, /*xor=*/0); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_ODD_EVEN, OddEven, eortb_n) |
| #undef HWY_SVE_ODD_EVEN |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V OddEven(const V odd, const V even) { |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du; |
| return BitCast(d, OddEven(BitCast(du, odd), BitCast(du, even))); |
| } |
| |
| #else |
| |
| template <class V> |
| HWY_API V OddEven(const V odd, const V even) { |
| const auto odd_in_even = detail::Ext<1>(odd, odd); |
| return detail::InterleaveEven(even, odd_in_even); |
| } |
| |
#endif  // HWY_SVE_HAVE_2
| |
| // ------------------------------ InterleaveEven |
| template <class D> |
| HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) { |
| return detail::InterleaveEven(a, b); |
| } |
| |
| // ------------------------------ InterleaveOdd |
| template <class D> |
| HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) { |
| return detail::InterleaveOdd(a, b); |
| } |
| |
| // ------------------------------ OddEvenBlocks |
| template <class V> |
| HWY_API V OddEvenBlocks(const V odd, const V even) { |
| const DFromV<V> d; |
| #if HWY_TARGET == HWY_SVE_256 |
| return ConcatUpperLower(d, odd, even); |
| #elif HWY_TARGET == HWY_SVE2_128 |
| (void)odd; |
| (void)d; |
| return even; |
| #else |
| const RebindToUnsigned<decltype(d)> du; |
| using TU = TFromD<decltype(du)>; |
| constexpr size_t kShift = CeilLog2(16 / sizeof(TU)); |
| const auto idx_block = ShiftRight<kShift>(Iota(du, 0)); |
| const auto lsb = detail::AndN(idx_block, static_cast<TU>(1)); |
| const svbool_t is_even = detail::EqN(lsb, static_cast<TU>(0)); |
| return IfThenElse(is_even, even, odd); |
| #endif |
| } |
| |
| // ------------------------------ TableLookupLanes |
| |
| template <class D, class VI> |
| HWY_API VFromD<RebindToUnsigned<D>> IndicesFromVec(D d, VI vec) { |
| using TI = TFromV<VI>; |
| static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index/lane size mismatch"); |
| const RebindToUnsigned<D> du; |
| const auto indices = BitCast(du, vec); |
| #if HWY_IS_DEBUG_BUILD |
| using TU = MakeUnsigned<TI>; |
| const size_t twice_max_lanes = Lanes(d) * 2; |
| HWY_DASSERT(AllTrue( |
| du, Eq(indices, |
| detail::AndN(indices, static_cast<TU>(twice_max_lanes - 1))))); |
| #else |
| (void)d; |
| #endif |
| return indices; |
| } |
| |
| template <class D, typename TI> |
| HWY_API VFromD<RebindToUnsigned<D>> SetTableIndices(D d, const TI* idx) { |
| static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane"); |
| return IndicesFromVec(d, LoadU(Rebind<TI, D>(), idx)); |
| } |
| |
| #define HWY_SVE_TABLE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \ |
| return sv##OP##_##CHAR##BITS(v, idx); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_TABLE, TableLookupLanes, tbl) |
| #endif |
| #undef HWY_SVE_TABLE |
| |
| #if HWY_SVE_HAVE_2 |
| namespace detail { |
| #define HWY_SVE_TABLE2(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(uint, BITS) idx) { \ |
| return sv##OP##_##CHAR##BITS(tuple, idx); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, tbl2) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, |
| tbl2) |
| #endif |
#undef HWY_SVE_TABLE2
| } // namespace detail |
| #endif // HWY_SVE_HAVE_2 |
| |
| template <class D> |
| HWY_API VFromD<D> TwoTablesLookupLanes(D d, VFromD<D> a, VFromD<D> b, |
| VFromD<RebindToUnsigned<D>> idx) { |
| // SVE2 has an instruction for this, but it only works for full 2^n vectors. |
| #if HWY_SVE_HAVE_2 && HWY_SVE_IS_POW2 |
| if (detail::IsFull(d)) { |
| return detail::NativeTwoTableLookupLanes(Create2(d, a, b), idx); |
| } |
| #endif |
| const RebindToUnsigned<decltype(d)> du; |
| using TU = TFromD<decltype(du)>; |
| |
| const size_t num_of_lanes = Lanes(d); |
| const auto idx_mod = detail::AndN(idx, static_cast<TU>(num_of_lanes - 1)); |
| const auto sel_a_mask = Eq(idx, idx_mod); |
| |
| const auto a_lookup_result = TableLookupLanes(a, idx_mod); |
| const auto b_lookup_result = TableLookupLanes(b, idx_mod); |
| return IfThenElse(sel_a_mask, a_lookup_result, b_lookup_result); |
| } |
| |
| template <class V> |
| HWY_API V TwoTablesLookupLanes(V a, V b, |
| VFromD<RebindToUnsigned<DFromV<V>>> idx) { |
| const DFromV<decltype(a)> d; |
| return TwoTablesLookupLanes(d, a, b, idx); |
| } |
| |
| // ------------------------------ SlideUpLanes (FirstN) |
| template <class D> |
| HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) { |
| return detail::Splice(v, Zero(d), FirstN(d, amt)); |
| } |
| |
| // ------------------------------ Slide1Up |
| |
| #ifdef HWY_NATIVE_SLIDE1_UP_DOWN |
| #undef HWY_NATIVE_SLIDE1_UP_DOWN |
| #else |
| #define HWY_NATIVE_SLIDE1_UP_DOWN |
| #endif |
| |
| template <class D> |
| HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) { |
| return SlideUpLanes(d, v, 1); |
| } |
| |
| // ------------------------------ SlideDownLanes (TableLookupLanes) |
| template <class D> |
| HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) { |
| const RebindToUnsigned<decltype(d)> du; |
| using TU = TFromD<decltype(du)>; |
| const auto idx = Iota(du, static_cast<TU>(amt)); |
| return IfThenElseZero(FirstN(d, Lanes(d) - amt), TableLookupLanes(v, idx)); |
| } |
| |
| // ------------------------------ Slide1Down |
| template <class D> |
| HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) { |
| return SlideDownLanes(d, v, 1); |
| } |
| |
| // ------------------------------ SwapAdjacentBlocks (TableLookupLanes) |
| |
| namespace detail { |
| |
| template <typename T, size_t N, int kPow2> |
| constexpr size_t LanesPerBlock(Simd<T, N, kPow2> d) { |
| // We might have a capped vector smaller than a block, so honor that. |
| return HWY_MIN(16 / sizeof(T), MaxLanes(d)); |
| } |
| |
| } // namespace detail |
| |
| template <class V> |
| HWY_API V SwapAdjacentBlocks(const V v) { |
| const DFromV<V> d; |
| #if HWY_TARGET == HWY_SVE_256 |
| return ConcatLowerUpper(d, v, v); |
| #elif HWY_TARGET == HWY_SVE2_128 |
| (void)d; |
| return v; |
| #else |
| const RebindToUnsigned<decltype(d)> du; |
| constexpr auto kLanesPerBlock = |
| static_cast<TFromD<decltype(du)>>(detail::LanesPerBlock(d)); |
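  // XOR-ing lane indices with kLanesPerBlock toggles the block, e.g. for
  // 32-bit lanes: idx = {4, 5, 6, 7, 0, 1, 2, 3, ...}.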
| const VFromD<decltype(du)> idx = detail::XorN(Iota(du, 0), kLanesPerBlock); |
| return TableLookupLanes(v, idx); |
| #endif |
| } |
| |
| // ------------------------------ InterleaveEvenBlocks |
| // (ConcatLowerLower, SlideUpLanes, OddEvenBlocks) |
| |
| template <class D, class V = VFromD<D>> |
| HWY_API V InterleaveEvenBlocks(D d, V a, V b) { |
| #if HWY_TARGET == HWY_SVE_256 |
| return ConcatLowerLower(d, b, a); |
| #elif HWY_TARGET == HWY_SVE2_128 |
| (void)d; |
| (void)b; |
| return a; |
| #else |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); |
| return OddEvenBlocks(SlideUpLanes(d, b, kLanesPerBlock), a); |
| #endif |
| } |
| |
| // ------------------------------ InterleaveOddBlocks |
| // (ConcatUpperUpper, SlideDownLanes, OddEvenBlocks) |
| |
| template <class D, class V = VFromD<D>> |
| HWY_API V InterleaveOddBlocks(D d, V a, V b) { |
| #if HWY_TARGET == HWY_SVE_256 |
| return ConcatUpperUpper(d, b, a); |
| #elif HWY_TARGET == HWY_SVE2_128 |
| (void)d; |
| (void)b; |
| return a; |
| #else |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); |
| return OddEvenBlocks(b, SlideDownLanes(d, a, kLanesPerBlock)); |
| #endif |
| } |
| |
| // ------------------------------ Reverse |
| |
| namespace detail { |
| |
| #define HWY_SVE_REVERSE(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS(v); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_REVERSE, ReverseFull, rev) |
| #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC |
| HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_REVERSE, ReverseFull, rev) |
| #endif |
| #undef HWY_SVE_REVERSE |
| |
| } // namespace detail |
| |
| template <class D, class V> |
| HWY_API V Reverse(D d, V v) { |
| using T = TFromD<D>; |
| const auto reversed = detail::ReverseFull(v); |
| if (HWY_SVE_IS_POW2 && detail::IsFull(d)) return reversed; |
| // Shift right to remove extra (non-pow2 and remainder) lanes. |
| // TODO(janwas): on SVE2, use WHILEGE. |
| // Avoids FirstN truncating to the return vector size. Must also avoid Not |
| // because that is limited to SV_POW2. |
| const ScalableTag<T> dfull; |
| const svbool_t all_true = detail::AllPTrue(dfull); |
| const size_t all_lanes = detail::AllHardwareLanes<T>(); |
| const size_t want_lanes = Lanes(d); |
| HWY_DASSERT(want_lanes <= all_lanes); |
| const svbool_t mask = |
| svnot_b_z(all_true, FirstN(dfull, all_lanes - want_lanes)); |
| return detail::Splice(reversed, reversed, mask); |
| } |
| |
| // ------------------------------ Reverse2 |
| |
| // Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. |
| #ifdef HWY_NATIVE_REVERSE2_8 |
| #undef HWY_NATIVE_REVERSE2_8 |
| #else |
| #define HWY_NATIVE_REVERSE2_8 |
| #endif |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { |
| const RebindToUnsigned<decltype(d)> du; |
| const RepartitionToWide<decltype(du)> dw; |
| return BitCast(d, svrevb_u16_x(detail::PTrue(d), BitCast(dw, v))); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { |
| const RebindToUnsigned<decltype(d)> du; |
| const RepartitionToWide<decltype(du)> dw; |
| return BitCast(d, svrevh_u32_x(detail::PTrue(d), BitCast(dw, v))); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 4)> |
| HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { |
| const RebindToUnsigned<decltype(d)> du; |
| const RepartitionToWide<decltype(du)> dw; |
| return BitCast(d, svrevw_u64_x(detail::PTrue(d), BitCast(dw, v))); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 8)> |
| HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { // 3210 |
| #if HWY_TARGET == HWY_SVE2_128 |
| if (detail::IsFull(d)) { |
| return detail::Ext<1>(v, v); |
| } |
| #endif |
| (void)d; |
| const auto odd_in_even = detail::Ext<1>(v, v); // x321 |
| return detail::InterleaveEven(odd_in_even, v); // 2301 |
| } |
| |
| // ------------------------------ Reverse4 (TableLookupLanes) |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { |
| const RebindToUnsigned<decltype(d)> du; |
| const RepartitionToWideX2<decltype(du)> du32; |
| return BitCast(d, svrevb_u32_x(detail::PTrue(d), BitCast(du32, v))); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { |
| const RebindToUnsigned<decltype(d)> du; |
| const RepartitionToWideX2<decltype(du)> du64; |
| return BitCast(d, svrevh_u64_x(detail::PTrue(d), BitCast(du64, v))); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 4)> |
| HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { |
| if (HWY_TARGET == HWY_SVE2_128 && detail::IsFull(d)) { |
| return detail::ReverseFull(v); |
| } |
| // TODO(janwas): is this approach faster than Shuffle0123? |
| const RebindToUnsigned<decltype(d)> du; |
| const auto idx = detail::XorN(Iota(du, 0), 3); |
| return TableLookupLanes(v, idx); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 8)> |
| HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { |
| if (HWY_TARGET == HWY_SVE_256 && detail::IsFull(d)) { |
| return detail::ReverseFull(v); |
| } |
| // TODO(janwas): is this approach faster than Shuffle0123? |
| const RebindToUnsigned<decltype(d)> du; |
| const auto idx = detail::XorN(Iota(du, 0), 3); |
| return TableLookupLanes(v, idx); |
| } |
| |
| // ------------------------------ Reverse8 (TableLookupLanes) |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { |
| const Repartition<uint64_t, decltype(d)> du64; |
| return BitCast(d, svrevb_u64_x(detail::PTrue(d), BitCast(du64, v))); |
| } |
| |
| template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)> |
| HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { |
| const RebindToUnsigned<decltype(d)> du; |
| const auto idx = detail::XorN(Iota(du, 0), 7); |
| return TableLookupLanes(v, idx); |
| } |
| |
| // ------------------------------- ReverseBits |
| |
| #ifdef HWY_NATIVE_REVERSE_BITS_UI8 |
| #undef HWY_NATIVE_REVERSE_BITS_UI8 |
| #else |
| #define HWY_NATIVE_REVERSE_BITS_UI8 |
| #endif |
| |
| #ifdef HWY_NATIVE_REVERSE_BITS_UI16_32_64 |
| #undef HWY_NATIVE_REVERSE_BITS_UI16_32_64 |
| #else |
| #define HWY_NATIVE_REVERSE_BITS_UI16_32_64 |
| #endif |
| |
| #define HWY_SVE_REVERSE_BITS(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| const DFromV<decltype(v)> d; \ |
| return sv##OP##_##CHAR##BITS##_x(detail::PTrue(d), v); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_REVERSE_BITS, ReverseBits, rbit) |
| #undef HWY_SVE_REVERSE_BITS |
| |
| // ------------------------------ Block insert/extract/broadcast ops |
| #if HWY_TARGET != HWY_SVE2_128 |
| |
| #ifdef HWY_NATIVE_BLK_INSERT_EXTRACT |
| #undef HWY_NATIVE_BLK_INSERT_EXTRACT |
| #else |
| #define HWY_NATIVE_BLK_INSERT_EXTRACT |
| #endif |
| |
| template <int kBlockIdx, class V> |
| HWY_API V InsertBlock(V v, V blk_to_insert) { |
| const DFromV<decltype(v)> d; |
| static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), |
| "Invalid block index"); |
| |
| #if HWY_TARGET == HWY_SVE_256 |
| return (kBlockIdx == 0) ? ConcatUpperLower(d, v, blk_to_insert) |
| : ConcatLowerLower(d, blk_to_insert, v); |
| #else |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); |
| |
| constexpr size_t kBlockOffset = |
| static_cast<size_t>(kBlockIdx) * kLanesPerBlock; |
| const auto splice_mask = FirstN(d, kBlockOffset); |
| const auto sel_lo_mask = FirstN(d, kBlockOffset + kLanesPerBlock); |
| |
| const auto splice_result = detail::Splice(blk_to_insert, v, splice_mask); |
| return IfThenElse(sel_lo_mask, splice_result, v); |
| #endif |
| } |
| |
| template <int kBlockIdx, class V> |
| HWY_API V ExtractBlock(V v) { |
| const DFromV<decltype(v)> d; |
| static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), |
| "Invalid block index"); |
| |
| if (kBlockIdx == 0) return v; |
| |
| #if HWY_TARGET == HWY_SVE_256 |
| return UpperHalf(Half<decltype(d)>(), v); |
| #else |
| const RebindToUnsigned<decltype(d)> du; |
| using TU = TFromD<decltype(du)>; |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); |
| constexpr size_t kBlockOffset = |
| static_cast<size_t>(kBlockIdx) * kLanesPerBlock; |
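  // Unsigned wraparound: lane i holds i - kBlockOffset, so LtN selects exactly
  // the lanes [kBlockOffset, kBlockOffset + kLanesPerBlock).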
| const auto splice_mask = |
| RebindMask(d, detail::LtN(Iota(du, static_cast<TU>(0u - kBlockOffset)), |
| static_cast<TU>(kLanesPerBlock))); |
| return detail::Splice(v, v, splice_mask); |
| #endif |
| } |
| |
| template <int kBlockIdx, class V> |
| HWY_API V BroadcastBlock(V v) { |
| const DFromV<decltype(v)> d; |
| static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), |
| "Invalid block index"); |
| |
| const RebindToUnsigned<decltype(d)> du; // for bfloat16_t |
| using VU = VFromD<decltype(du)>; |
| const VU vu = BitCast(du, v); |
| |
| #if HWY_TARGET == HWY_SVE_256 |
| return BitCast(d, (kBlockIdx == 0) ? ConcatLowerLower(du, vu, vu) |
| : ConcatUpperUpper(du, vu, vu)); |
| #else |
| using TU = TFromD<decltype(du)>; |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); |
| constexpr size_t kBlockOffset = |
| static_cast<size_t>(kBlockIdx) * kLanesPerBlock; |
| |
| const VU idx = detail::AddN( |
| detail::AndN(Iota(du, TU{0}), static_cast<TU>(kLanesPerBlock - 1)), |
| static_cast<TU>(kBlockOffset)); |
| return BitCast(d, TableLookupLanes(vu, idx)); |
| #endif |
| } |
| |
| #endif // HWY_TARGET != HWY_SVE2_128 |
| |
| // ------------------------------ Compress (PromoteTo) |
| |
| template <typename T> |
| struct CompressIsPartition { |
| #if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 |
| // Optimization for 64-bit lanes (could also be applied to 32-bit, but that |
| // requires a larger table). |
| enum { value = (sizeof(T) == 8) }; |
| #else |
| enum { value = 0 }; |
#endif  // HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128
| }; |
| |
| #define HWY_SVE_COMPRESS(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ |
| return sv##OP##_##CHAR##BITS(mask, v); \ |
| } |
| |
| #if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 |
| HWY_SVE_FOREACH_UI32(HWY_SVE_COMPRESS, Compress, compact) |
| HWY_SVE_FOREACH_F32(HWY_SVE_COMPRESS, Compress, compact) |
| #else |
| HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact) |
| #endif |
| #undef HWY_SVE_COMPRESS |
| |
| #if HWY_TARGET == HWY_SVE_256 || HWY_IDE |
| template <class V, HWY_IF_T_SIZE_V(V, 8)> |
| HWY_API V Compress(V v, svbool_t mask) { |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du64; |
| |
| // Convert mask into bitfield via horizontal sum (faster than ORV) of masked |
| // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for |
| // SetTableIndices. |
| const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); |
| const size_t offset = detail::SumOfLanesM(mask, bits); |
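  // Example: mask lanes {1, 0, 1, 0} sum to offset 4 * (1 + 4) = 20; the row
  // there, {0, 2, 1, 3}, moves the two active lanes to the front.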
| |
| // See CompressIsPartition. |
| alignas(16) static constexpr uint64_t table[4 * 16] = { |
| // PrintCompress64x4Tables |
| 0, 1, 2, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 2, 0, 1, 3, 0, 2, |
| 1, 3, 1, 2, 0, 3, 0, 1, 2, 3, 3, 0, 1, 2, 0, 3, 1, 2, 1, 3, 0, 2, |
| 0, 1, 3, 2, 2, 3, 0, 1, 0, 2, 3, 1, 1, 2, 3, 0, 0, 1, 2, 3}; |
| return TableLookupLanes(v, SetTableIndices(d, table + offset)); |
| } |
| |
| #endif // HWY_TARGET == HWY_SVE_256 |
| #if HWY_TARGET == HWY_SVE2_128 || HWY_IDE |
| template <class V, HWY_IF_T_SIZE_V(V, 8)> |
| HWY_API V Compress(V v, svbool_t mask) { |
  // Only mask == 10 requires a swap; masks 00, 01 and 11 leave v unchanged
  // (see CompressIsPartition). Splice with mask 10 performs the swap: the
  // lower half is set to the upper half, and the remaining upper half is
  // filled from the lower half of the second v. Mask 01 must not reach Splice
  // because it would ConcatLowerLower; zip1 and AndNot keep 10 unchanged and
  // map everything else to 00.
| const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane |
| return detail::Splice(v, v, AndNot(maskLL, mask)); |
| } |
| |
| #endif // HWY_TARGET == HWY_SVE2_128 |
| |
| template <class V, HWY_IF_T_SIZE_V(V, 2)> |
| HWY_API V Compress(V v, svbool_t mask16) { |
| static_assert(!IsSame<V, svfloat16_t>(), "Must use overload"); |
| const DFromV<V> d16; |
| |
| // Promote vector and mask to 32-bit |
| const RepartitionToWide<decltype(d16)> dw; |
| const auto v32L = PromoteTo(dw, v); |
| const auto v32H = detail::PromoteUpperTo(dw, v); |
| const svbool_t mask32L = svunpklo_b(mask16); |
| const svbool_t mask32H = svunpkhi_b(mask16); |
| |
| const auto compressedL = Compress(v32L, mask32L); |
| const auto compressedH = Compress(v32H, mask32H); |
| |
| // Demote to 16-bit (already in range) - separately so we can splice |
| const V evenL = BitCast(d16, compressedL); |
| const V evenH = BitCast(d16, compressedH); |
| const V v16L = detail::ConcatEvenFull(evenL, evenL); // lower half |
| const V v16H = detail::ConcatEvenFull(evenH, evenH); |
| |
| // We need to combine two vectors of non-constexpr length, so the only option |
| // is Splice, which requires us to synthesize a mask. NOTE: this function uses |
| // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt. |
| const size_t countL = detail::CountTrueFull(dw, mask32L); |
| const auto compressed_maskL = FirstN(d16, countL); |
| return detail::Splice(v16H, v16L, compressed_maskL); |
| } |
| |
| // Must treat float16_t as integers so we can ConcatEven. |
| HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) { |
| const DFromV<decltype(v)> df; |
| const RebindToSigned<decltype(df)> di; |
| return BitCast(df, Compress(BitCast(di, v), mask16)); |
| } |
| |
| // ------------------------------ CompressNot |
| |
| // 2 or 4 bytes |
| template <class V, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 2) | (1 << 4))> |
| HWY_API V CompressNot(V v, const svbool_t mask) { |
| return Compress(v, Not(mask)); |
| } |
| |
| template <class V, HWY_IF_T_SIZE_V(V, 8)> |
| HWY_API V CompressNot(V v, svbool_t mask) { |
| #if HWY_TARGET == HWY_SVE2_128 || HWY_IDE |
  // CompressNot must swap when mask == 01, i.e. Not(mask) == 10; masks 00, 10
  // and 11 leave v unchanged (see CompressIsPartition). Splice with mask 10
  // swaps upper/lower: the lower half is set to the upper half, and the
  // remaining upper half is filled from the lower half of the second v. zip1
  // and AndNot map 01 to 10 and everything else to 00, so Splice never sees
  // the invalid mask 01, which would ConcatLowerLower.
| const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane |
| return detail::Splice(v, v, AndNot(mask, maskLL)); |
| #endif |
| #if HWY_TARGET == HWY_SVE_256 || HWY_IDE |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du64; |
| |
| // Convert mask into bitfield via horizontal sum (faster than ORV) of masked |
| // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for |
| // SetTableIndices. |
| const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); |
| const size_t offset = detail::SumOfLanesM(mask, bits); |
| |
| // See CompressIsPartition. |
| alignas(16) static constexpr uint64_t table[4 * 16] = { |
| // PrintCompressNot64x4Tables |
| 0, 1, 2, 3, 1, 2, 3, 0, 0, 2, 3, 1, 2, 3, 0, 1, 0, 1, 3, 2, 1, 3, |
| 0, 2, 0, 3, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 0, 3, 0, 2, 1, 3, |
| 2, 0, 1, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; |
| return TableLookupLanes(v, SetTableIndices(d, table + offset)); |
| #endif // HWY_TARGET == HWY_SVE_256 |
| |
| return Compress(v, Not(mask)); |
| } |
| |
| // ------------------------------ CompressBlocksNot |
| HWY_API svuint64_t CompressBlocksNot(svuint64_t v, svbool_t mask) { |
| #if HWY_TARGET == HWY_SVE2_128 |
| (void)mask; |
| return v; |
| #endif |
| #if HWY_TARGET == HWY_SVE_256 || HWY_IDE |
  uint64_t bits = 0;  // predicate reg is 32-bit for 256-bit vectors
  // Copy only the 4 mask bytes; the 64-bit variable is more efficient.
  CopyBytes<4>(&mask, &bits);
| // Concatenate LSB for upper and lower blocks, pre-scale by 4 for table idx. |
| const size_t offset = ((bits & 1) ? 4u : 0u) + ((bits & 0x10000) ? 8u : 0u); |
| // See CompressIsPartition. Manually generated; flip halves if mask = [0, 1]. |
| alignas(16) static constexpr uint64_t table[4 * 4] = {0, 1, 2, 3, 2, 3, 0, 1, |
| 0, 1, 2, 3, 0, 1, 2, 3}; |
| const ScalableTag<uint64_t> d; |
| return TableLookupLanes(v, SetTableIndices(d, table + offset)); |
| #endif |
| |
| return CompressNot(v, mask); |
| } |
| |
| // ------------------------------ CompressStore |
| template <class V, class D, HWY_IF_NOT_T_SIZE_D(D, 1)> |
| HWY_API size_t CompressStore(const V v, const svbool_t mask, const D d, |
| TFromD<D>* HWY_RESTRICT unaligned) { |
| StoreU(Compress(v, mask), d, unaligned); |
| return CountTrue(d, mask); |
| } |
| |
| // ------------------------------ CompressBlendedStore |
| template <class V, class D, HWY_IF_NOT_T_SIZE_D(D, 1)> |
| HWY_API size_t CompressBlendedStore(const V v, const svbool_t mask, const D d, |
| TFromD<D>* HWY_RESTRICT unaligned) { |
| const size_t count = CountTrue(d, mask); |
| const svbool_t store_mask = FirstN(d, count); |
| BlendedStore(Compress(v, mask), store_mask, d, unaligned); |
| return count; |
| } |
| |
| // ================================================== MASK (2) |
| |
| // ------------------------------ FindKnownLastTrue |
| template <class D> |
| HWY_API size_t FindKnownLastTrue(D d, svbool_t m) { |
| const RebindToUnsigned<decltype(d)> du; |
| return static_cast<size_t>(detail::ExtractLastMatchingLaneM( |
| Iota(du, 0), And(m, detail::MakeMask(d)))); |
| } |
| |
| // ------------------------------ FindLastTrue |
| template <class D> |
| HWY_API intptr_t FindLastTrue(D d, svbool_t m) { |
| return AllFalse(d, m) ? intptr_t{-1} |
| : static_cast<intptr_t>(FindKnownLastTrue(d, m)); |
| } |
| |
| // ================================================== BLOCKWISE |
| |
| // ------------------------------ CombineShiftRightBytes |
| |
| // Prevent accidentally using these for 128-bit vectors - should not be |
| // necessary. |
| #if HWY_TARGET != HWY_SVE2_128 |
| namespace detail { |
| |
| // For x86-compatible behaviour mandated by Highway API: TableLookupBytes |
| // offsets are implicitly relative to the start of their 128-bit block. |
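// Example: for 8-bit lanes, lane index 19 maps to 16, the start of its block.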
| template <class D, class V> |
| HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { |
| using T = MakeUnsigned<TFromD<D>>; |
| return detail::AndNotN(static_cast<T>(LanesPerBlock(d) - 1), iota0); |
| } |
| |
| template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 1)> |
| svbool_t FirstNPerBlock(D d) { |
| const RebindToUnsigned<decltype(d)> du; |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); |
| const svuint8_t idx_mod = |
| svdupq_n_u8(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, |
| 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, |
| 6 % kLanesPerBlock, 7 % kLanesPerBlock, 8 % kLanesPerBlock, |
| 9 % kLanesPerBlock, 10 % kLanesPerBlock, 11 % kLanesPerBlock, |
| 12 % kLanesPerBlock, 13 % kLanesPerBlock, 14 % kLanesPerBlock, |
| 15 % kLanesPerBlock); |
| return detail::LtN(BitCast(du, idx_mod), kLanes); |
| } |
| template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 2)> |
| svbool_t FirstNPerBlock(D d) { |
| const RebindToUnsigned<decltype(d)> du; |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); |
| const svuint16_t idx_mod = |
| svdupq_n_u16(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, |
| 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, |
| 6 % kLanesPerBlock, 7 % kLanesPerBlock); |
| return detail::LtN(BitCast(du, idx_mod), kLanes); |
| } |
| template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 4)> |
| svbool_t FirstNPerBlock(D d) { |
| const RebindToUnsigned<decltype(d)> du; |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); |
| const svuint32_t idx_mod = |
| svdupq_n_u32(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, |
| 3 % kLanesPerBlock); |
| return detail::LtN(BitCast(du, idx_mod), kLanes); |
| } |
| template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 8)> |
| svbool_t FirstNPerBlock(D d) { |
| const RebindToUnsigned<decltype(d)> du; |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); |
| const svuint64_t idx_mod = |
| svdupq_n_u64(0 % kLanesPerBlock, 1 % kLanesPerBlock); |
| return detail::LtN(BitCast(du, idx_mod), kLanes); |
| } |
| |
| } // namespace detail |
| #endif // HWY_TARGET != HWY_SVE2_128 |
| |
| template <size_t kBytes, class D, class V = VFromD<D>> |
| HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) { |
| const Repartition<uint8_t, decltype(d)> d8; |
| const auto hi8 = BitCast(d8, hi); |
| const auto lo8 = BitCast(d8, lo); |
| #if HWY_TARGET == HWY_SVE2_128 |
| return BitCast(d, detail::Ext<kBytes>(hi8, lo8)); |
| #else |
| const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes)); |
| const auto lo_down = detail::Ext<kBytes>(lo8, lo8); |
| const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); |
| return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); |
| #endif |
| } |
| |
| // ------------------------------ Shuffle2301 |
| template <class V> |
| HWY_API V Shuffle2301(const V v) { |
| const DFromV<V> d; |
| static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); |
| return Reverse2(d, v); |
| } |
| |
| // ------------------------------ Shuffle2103 |
| template <class V> |
| HWY_API V Shuffle2103(const V v) { |
| const DFromV<V> d; |
| const Repartition<uint8_t, decltype(d)> d8; |
| static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); |
| const svuint8_t v8 = BitCast(d8, v); |
| return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8)); |
| } |
| |
| // ------------------------------ Shuffle0321 |
| template <class V> |
| HWY_API V Shuffle0321(const V v) { |
| const DFromV<V> d; |
| const Repartition<uint8_t, decltype(d)> d8; |
| static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); |
| const svuint8_t v8 = BitCast(d8, v); |
| return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8)); |
| } |
| |
| // ------------------------------ Shuffle1032 |
| template <class V> |
| HWY_API V Shuffle1032(const V v) { |
| const DFromV<V> d; |
| const Repartition<uint8_t, decltype(d)> d8; |
| static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); |
| const svuint8_t v8 = BitCast(d8, v); |
| return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); |
| } |
| |
| // ------------------------------ Shuffle01 |
| template <class V> |
| HWY_API V Shuffle01(const V v) { |
| const DFromV<V> d; |
| const Repartition<uint8_t, decltype(d)> d8; |
| static_assert(sizeof(TFromD<decltype(d)>) == 8, "Defined for 64-bit types"); |
| const svuint8_t v8 = BitCast(d8, v); |
| return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); |
| } |
| |
| // ------------------------------ Shuffle0123 |
| template <class V> |
| HWY_API V Shuffle0123(const V v) { |
| return Shuffle2301(Shuffle1032(v)); |
| } |
| |
| // ------------------------------ ReverseBlocks (Reverse, Shuffle01) |
| template <class D, class V = VFromD<D>> |
| HWY_API V ReverseBlocks(D d, V v) { |
| #if HWY_TARGET == HWY_SVE_256 |
| if (detail::IsFull(d)) { |
| return SwapAdjacentBlocks(v); |
| } else if (detail::IsFull(Twice<D>())) { |
| return v; |
| } |
| #elif HWY_TARGET == HWY_SVE2_128 |
| (void)d; |
| return v; |
| #endif |
| const Repartition<uint64_t, D> du64; |
| return BitCast(d, Shuffle01(Reverse(du64, BitCast(du64, v)))); |
| } |
| |
| // ------------------------------ TableLookupBytes |
| |
| template <class V, class VI> |
| HWY_API VI TableLookupBytes(const V v, const VI idx) { |
| const DFromV<VI> d; |
| const Repartition<uint8_t, decltype(d)> du8; |
| #if HWY_TARGET == HWY_SVE2_128 |
| return BitCast(d, TableLookupLanes(BitCast(du8, v), BitCast(du8, idx))); |
| #else |
| const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0)); |
| const auto idx8 = Add(BitCast(du8, idx), offsets128); |
| return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8)); |
| #endif |
| } |
| |
| template <class V, class VI> |
| HWY_API VI TableLookupBytesOr0(const V v, const VI idx) { |
| const DFromV<VI> d; |
| // Mask size must match vector type, so cast everything to this type. |
| const Repartition<int8_t, decltype(d)> di8; |
| |
| auto idx8 = BitCast(di8, idx); |
| const auto msb = detail::LtN(idx8, 0); |
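  // Indices with the MSB set are negative as int8_t; those lanes are zeroed
  // below, matching x86 pshufb semantics.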
| |
| const auto lookup = TableLookupBytes(BitCast(di8, v), idx8); |
| return BitCast(d, IfThenZeroElse(msb, lookup)); |
| } |
| |
| // ------------------------------ Broadcast |
| |
| #ifdef HWY_NATIVE_BROADCASTLANE |
| #undef HWY_NATIVE_BROADCASTLANE |
| #else |
| #define HWY_NATIVE_BROADCASTLANE |
| #endif |
| |
| namespace detail { |
| #define HWY_SVE_BROADCAST(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <int kLane> \ |
| HWY_INLINE HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS(v, kLane); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_BROADCAST, BroadcastLane, dup_lane) |
| #undef HWY_SVE_BROADCAST |
| } // namespace detail |
| |
| template <int kLane, class V> |
| HWY_API V Broadcast(const V v) { |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du; |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); |
| static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane"); |
| #if HWY_TARGET == HWY_SVE2_128 |
| return detail::BroadcastLane<kLane>(v); |
| #else |
| auto idx = detail::OffsetsOf128BitBlocks(du, Iota(du, 0)); |
| if (kLane != 0) { |
| idx = detail::AddN(idx, kLane); |
| } |
| return TableLookupLanes(v, idx); |
| #endif |
| } |
| |
| template <int kLane, class V> |
| HWY_API V BroadcastLane(const V v) { |
| static_assert(0 <= kLane && kLane < HWY_MAX_LANES_V(V), "Invalid lane"); |
| return detail::BroadcastLane<kLane>(v); |
| } |
| |
| // ------------------------------ ShiftLeftLanes |
| |
| template <size_t kLanes, class D, class V = VFromD<D>> |
| HWY_API V ShiftLeftLanes(D d, const V v) { |
| const auto zero = Zero(d); |
| const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes)); |
| #if HWY_TARGET == HWY_SVE2_128 |
| return shifted; |
| #else |
| // Match x86 semantics by zeroing lower lanes in 128-bit blocks |
| return IfThenElse(detail::FirstNPerBlock<kLanes>(d), zero, shifted); |
| #endif |
| } |
| |
| template <size_t kLanes, class V> |
| HWY_API V ShiftLeftLanes(const V v) { |
| return ShiftLeftLanes<kLanes>(DFromV<V>(), v); |
| } |
| |
| // ------------------------------ ShiftRightLanes |
| template <size_t kLanes, class D, class V = VFromD<D>> |
| HWY_API V ShiftRightLanes(D d, V v) { |
| // For capped/fractional vectors, clear upper lanes so we shift in zeros. |
| if (!detail::IsFull(d)) { |
| v = IfThenElseZero(detail::MakeMask(d), v); |
| } |
| |
| #if HWY_TARGET == HWY_SVE2_128 |
| return detail::Ext<kLanes>(Zero(d), v); |
| #else |
| const auto shifted = detail::Ext<kLanes>(v, v); |
| // Match x86 semantics by zeroing upper lanes in 128-bit blocks |
| constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); |
| const svbool_t mask = detail::FirstNPerBlock<kLanesPerBlock - kLanes>(d); |
| return IfThenElseZero(mask, shifted); |
| #endif |
| } |
| |
| // ------------------------------ ShiftLeftBytes |
| |
| template <int kBytes, class D, class V = VFromD<D>> |
| HWY_API V ShiftLeftBytes(const D d, const V v) { |
| const Repartition<uint8_t, decltype(d)> d8; |
| return BitCast(d, ShiftLeftLanes<kBytes>(BitCast(d8, v))); |
| } |
| |
| template <int kBytes, class V> |
| HWY_API V ShiftLeftBytes(const V v) { |
| return ShiftLeftBytes<kBytes>(DFromV<V>(), v); |
| } |
| |
| // ------------------------------ ShiftRightBytes |
| template <int kBytes, class D, class V = VFromD<D>> |
| HWY_API V ShiftRightBytes(const D d, const V v) { |
| const Repartition<uint8_t, decltype(d)> d8; |
| return BitCast(d, ShiftRightLanes<kBytes>(d8, BitCast(d8, v))); |
| } |
| |
| // ------------------------------ ZipLower |
| |
| template <class V, class DW = RepartitionToWide<DFromV<V>>> |
| HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) { |
| const RepartitionToNarrow<DW> dn; |
| static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch"); |
| return BitCast(dw, InterleaveLower(dn, a, b)); |
| } |
| template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> |
| HWY_API VFromD<DW> ZipLower(const V a, const V b) { |
| return BitCast(DW(), InterleaveLower(D(), a, b)); |
| } |
| |
| // ------------------------------ ZipUpper |
| template <class V, class DW = RepartitionToWide<DFromV<V>>> |
| HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) { |
| const RepartitionToNarrow<DW> dn; |
| static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch"); |
| return BitCast(dw, InterleaveUpper(dn, a, b)); |
| } |
| |
| // ================================================== Ops with dependencies |
| |
| // ------------------------------ AddSub (Reverse2) |
| |
| // NOTE: svcadd_f*_x(HWY_SVE_PTRUE(BITS), a, b, 90) computes a[i] - b[i + 1] in |
| // the even lanes and a[i] + b[i - 1] in the odd lanes. |
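//
// For example, with a = {1, 2, 3, 4} and b = {10, 20, 30, 40}, Reverse2(d, b)
// = {20, 10, 40, 30}, so the rotation-90 svcadd below yields
// {1 - 10, 2 + 20, 3 - 30, 4 + 40} = {-9, 22, -27, 44}: subtraction in even
// lanes, addition in odd lanes, as AddSub requires.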
| |
| #define HWY_SVE_ADDSUB_F(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ |
| const DFromV<decltype(b)> d; \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, Reverse2(d, b), \ |
| 90); \ |
| } |
| |
| HWY_SVE_FOREACH_F(HWY_SVE_ADDSUB_F, AddSub, cadd) |
| |
| #undef HWY_SVE_ADDSUB_F |
| |
| // NOTE: svcadd_s*(a, b, 90) and svcadd_u*(a, b, 90) compute a[i] - b[i + 1] in |
| // the even lanes and a[i] + b[i - 1] in the odd lanes. |
| |
| #if HWY_SVE_HAVE_2 |
| #define HWY_SVE_ADDSUB_UI(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ |
| const DFromV<decltype(b)> d; \ |
| return sv##OP##_##CHAR##BITS(a, Reverse2(d, b), 90); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_ADDSUB_UI, AddSub, cadd) |
| |
| #undef HWY_SVE_ADDSUB_UI |
| |
| // Disable the default implementation of AddSub in generic_ops-inl.h on SVE2 |
| #undef HWY_IF_ADDSUB_V |
| #define HWY_IF_ADDSUB_V(V) \ |
| HWY_IF_LANES_GT_D(DFromV<V>, 1), \ |
| hwy::EnableIf<!hwy::IsSame<V, V>()>* = nullptr |
| |
| #else // !HWY_SVE_HAVE_2 |
| |
| // Disable the default implementation of AddSub in generic_ops-inl.h for |
| // floating-point vectors on SVE, but enable the default implementation of |
| // AddSub in generic_ops-inl.h for integer vectors on SVE that do not support |
| // SVE2 |
| #undef HWY_IF_ADDSUB_V |
| #define HWY_IF_ADDSUB_V(V) \ |
| HWY_IF_LANES_GT_D(DFromV<V>, 1), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) |
| |
| #endif // HWY_SVE_HAVE_2 |
| |
| // ------------------------------ MulAddSub (AddSub) |
| |
| template <class V, HWY_IF_LANES_GT_D(DFromV<V>, 1), HWY_IF_FLOAT_V(V)> |
| HWY_API V MulAddSub(V mul, V x, V sub_or_add) { |
| using T = TFromV<V>; |
| |
| const DFromV<V> d; |
| const T neg_zero = ConvertScalarTo<T>(-0.0f); |
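  // AddSub(Set(d, -0.0), s) negates even lanes (-0.0 - s = -s) and passes
  // odd lanes through (-0.0 + s = +s), so the MulAdd below computes
  // mul * x - s in even lanes and mul * x + s in odd lanes.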
| |
| return MulAdd(mul, x, AddSub(Set(d, neg_zero), sub_or_add)); |
| } |
| |
| #if HWY_SVE_HAVE_2 |
| |
| // Disable the default implementation of MulAddSub in generic_ops-inl.h on SVE2 |
| #undef HWY_IF_MULADDSUB_V |
| #define HWY_IF_MULADDSUB_V(V) \ |
| HWY_IF_LANES_GT_D(DFromV<V>, 1), \ |
| hwy::EnableIf<!hwy::IsSame<V, V>()>* = nullptr |
| |
| template <class V, HWY_IF_LANES_GT_D(DFromV<V>, 1), |
| HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> |
| HWY_API V MulAddSub(V mul, V x, V sub_or_add) { |
| const DFromV<V> d; |
| return MulAdd(mul, x, AddSub(Zero(d), sub_or_add)); |
| } |
| |
| #else // !HWY_SVE_HAVE_2 |
| |
// Disable the default implementation of MulAddSub in generic_ops-inl.h for
// floating-point vectors on SVE, but enable the default implementation of
// MulAddSub in generic_ops-inl.h for integer vectors on SVE targets that do
// not support SVE2
| #undef HWY_IF_MULADDSUB_V |
| #define HWY_IF_MULADDSUB_V(V) \ |
| HWY_IF_LANES_GT_D(DFromV<V>, 1), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) |
| |
| #endif // HWY_SVE_HAVE_2 |
| |
| // ------------------------------ PromoteTo bfloat16 (ZipLower) |
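// bfloat16 is the upper half of an IEEE binary32, so promotion is a 16-bit
// left shift of each lane: interleaving a zero u16 below each bf16 lane
// (little-endian) places its bits into the upper half of the f32 lane.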
| template <size_t N, int kPow2> |
| HWY_API svfloat32_t PromoteTo(Simd<float32_t, N, kPow2> df32, VBF16 v) { |
| const ScalableTag<uint16_t> du16; |
| return BitCast(df32, detail::ZipLowerSame(svdup_n_u16(0), BitCast(du16, v))); |
| } |
| |
| // ------------------------------ PromoteEvenTo/PromoteOddTo (ConcatOddFull) |
| |
| namespace detail { |
| |
| // Signed to signed PromoteEvenTo |
| template <class D> |
| HWY_INLINE VFromD<D> PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, |
| hwy::SizeTag<2> /*to_lane_size_tag*/, |
| hwy::SignedTag /*from_type_tag*/, D d_to, |
| svint8_t v) { |
| return svextb_s16_x(detail::PTrue(d_to), BitCast(d_to, v)); |
| } |
| |
| template <class D> |
| HWY_INLINE VFromD<D> PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, |
| hwy::SizeTag<4> /*to_lane_size_tag*/, |
| hwy::SignedTag /*from_type_tag*/, D d_to, |
| svint16_t v) { |
| return svexth_s32_x(detail::PTrue(d_to), BitCast(d_to, v)); |
| } |
| |
| template <class D> |
| HWY_INLINE VFromD<D> PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, |
| hwy::SizeTag<8> /*to_lane_size_tag*/, |
| hwy::SignedTag /*from_type_tag*/, D d_to, |
| svint32_t v) { |
| return svextw_s64_x(detail::PTrue(d_to), BitCast(d_to, v)); |
| } |
| |
| // F16->F32 PromoteEvenTo |
| template <class D> |
| HWY_INLINE VFromD<D> PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, |
| hwy::SizeTag<4> /*to_lane_size_tag*/, |
| hwy::FloatTag /*from_type_tag*/, D d_to, |
| svfloat16_t v) { |
  const Repartition<float16_t, decltype(d_to)> d_from;
| return svcvt_f32_f16_x(detail::PTrue(d_from), v); |
| } |
| |
| // F32->F64 PromoteEvenTo |
| template <class D> |
| HWY_INLINE VFromD<D> PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, |
| hwy::SizeTag<8> /*to_lane_size_tag*/, |
| hwy::FloatTag /*from_type_tag*/, D d_to, |
| svfloat32_t v) { |
| const Repartition<float, decltype(d_to)> d_from; |
| return svcvt_f64_f32_x(detail::PTrue(d_from), v); |
| } |
| |
| // I32->F64 PromoteEvenTo |
| template <class D> |
| HWY_INLINE VFromD<D> PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, |
| hwy::SizeTag<8> /*to_lane_size_tag*/, |
| hwy::SignedTag /*from_type_tag*/, D d_to, |
| svint32_t v) { |
  const Repartition<int32_t, decltype(d_to)> d_from;
| return svcvt_f64_s32_x(detail::PTrue(d_from), v); |
| } |
| |
| // U32->F64 PromoteEvenTo |
| template <class D> |
| HWY_INLINE VFromD<D> PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, |
| hwy::SizeTag<8> /*to_lane_size_tag*/, |
| hwy::UnsignedTag /*from_type_tag*/, D d_to, |
| svuint32_t v) { |
  const Repartition<uint32_t, decltype(d_to)> d_from;
| return svcvt_f64_u32_x(detail::PTrue(d_from), v); |
| } |
| |
| // F32->I64 PromoteEvenTo |
| template <class D> |
| HWY_INLINE VFromD<D> PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, |
| hwy::SizeTag<8> /*to_lane_size_tag*/, |
| hwy::FloatTag /*from_type_tag*/, D d_to, |
| svfloat32_t v) { |
| const Repartition<float, decltype(d_to)> d_from; |
| return svcvt_s64_f32_x(detail::PTrue(d_from), v); |
| } |
| |
| // F32->U64 PromoteEvenTo |
| template <class D> |
| HWY_INLINE VFromD<D> PromoteEvenTo(hwy::UnsignedTag /*to_type_tag*/, |
| hwy::SizeTag<8> /*to_lane_size_tag*/, |
| hwy::FloatTag /*from_type_tag*/, D d_to, |
| svfloat32_t v) { |
| const Repartition<float, decltype(d_to)> d_from; |
| return svcvt_u64_f32_x(detail::PTrue(d_from), v); |
| } |
| |
| // F16->F32 PromoteOddTo |
| template <class D> |
| HWY_INLINE VFromD<D> PromoteOddTo(hwy::FloatTag to_type_tag, |
| hwy::SizeTag<4> to_lane_size_tag, |
| hwy::FloatTag from_type_tag, D d_to, |
| svfloat16_t v) { |
| return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, |
| DupOdd(v)); |
| } |
| |
| // I32/U32/F32->F64 PromoteOddTo |
| template <class FromTypeTag, class D, class V> |
| HWY_INLINE VFromD<D> PromoteOddTo(hwy::FloatTag to_type_tag, |
| hwy::SizeTag<8> to_lane_size_tag, |
| FromTypeTag from_type_tag, D d_to, V v) { |
| return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, |
| DupOdd(v)); |
| } |
| |
| // F32->I64/U64 PromoteOddTo |
| template <class ToTypeTag, class D, HWY_IF_UI64_D(D)> |
| HWY_INLINE VFromD<D> PromoteOddTo(ToTypeTag to_type_tag, |
| hwy::SizeTag<8> to_lane_size_tag, |
| hwy::FloatTag from_type_tag, D d_to, |
| svfloat32_t v) { |
| return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, |
| DupOdd(v)); |
| } |
| |
| } // namespace detail |
| |
| // ------------------------------ ReorderDemote2To (OddEven) |
| |
| template <size_t N, int kPow2> |
| HWY_API VBF16 ReorderDemote2To(Simd<bfloat16_t, N, kPow2> dbf16, svfloat32_t a, |
| svfloat32_t b) { |
| #if HWY_SVE_HAVE_F32_TO_BF16C |
| const VBF16 b_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), b); |
| return svcvtnt_bf16_f32_x(b_in_even, detail::PTrue(dbf16), a); |
| #else |
| const auto a_in_odd = |
| BitCast(ScalableTag<uint16_t>(), detail::RoundF32ForDemoteToBF16(a)); |
| const auto b_in_odd = |
| BitCast(ScalableTag<uint16_t>(), detail::RoundF32ForDemoteToBF16(b)); |
| return BitCast(dbf16, detail::InterleaveOdd(b_in_odd, a_in_odd)); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svint16_t ReorderDemote2To(Simd<int16_t, N, kPow2> d16, svint32_t a, |
| svint32_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d16; |
| const svint16_t a_in_even = svqxtnb_s32(a); |
| return svqxtnt_s32(a_in_even, b); |
| #else |
| const svint16_t a16 = BitCast(d16, detail::SaturateI<int16_t>(a)); |
| const svint16_t b16 = BitCast(d16, detail::SaturateI<int16_t>(b)); |
| return detail::InterleaveEven(a16, b16); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint16_t ReorderDemote2To(Simd<uint16_t, N, kPow2> d16, svint32_t a, |
| svint32_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d16; |
| const svuint16_t a_in_even = svqxtunb_s32(a); |
| return svqxtunt_s32(a_in_even, b); |
| #else |
| const Repartition<uint32_t, decltype(d16)> du32; |
| const svuint32_t clamped_a = BitCast(du32, detail::MaxN(a, 0)); |
| const svuint32_t clamped_b = BitCast(du32, detail::MaxN(b, 0)); |
| const svuint16_t a16 = BitCast(d16, detail::SaturateU<uint16_t>(clamped_a)); |
| const svuint16_t b16 = BitCast(d16, detail::SaturateU<uint16_t>(clamped_b)); |
| return detail::InterleaveEven(a16, b16); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint16_t ReorderDemote2To(Simd<uint16_t, N, kPow2> d16, svuint32_t a, |
| svuint32_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d16; |
| const svuint16_t a_in_even = svqxtnb_u32(a); |
| return svqxtnt_u32(a_in_even, b); |
| #else |
| const svuint16_t a16 = BitCast(d16, detail::SaturateU<uint16_t>(a)); |
| const svuint16_t b16 = BitCast(d16, detail::SaturateU<uint16_t>(b)); |
| return detail::InterleaveEven(a16, b16); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svint8_t ReorderDemote2To(Simd<int8_t, N, kPow2> d8, svint16_t a, |
| svint16_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d8; |
| const svint8_t a_in_even = svqxtnb_s16(a); |
| return svqxtnt_s16(a_in_even, b); |
| #else |
| const svint8_t a8 = BitCast(d8, detail::SaturateI<int8_t>(a)); |
| const svint8_t b8 = BitCast(d8, detail::SaturateI<int8_t>(b)); |
| return detail::InterleaveEven(a8, b8); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t ReorderDemote2To(Simd<uint8_t, N, kPow2> d8, svint16_t a, |
| svint16_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d8; |
| const svuint8_t a_in_even = svqxtunb_s16(a); |
| return svqxtunt_s16(a_in_even, b); |
| #else |
| const Repartition<uint16_t, decltype(d8)> du16; |
| const svuint16_t clamped_a = BitCast(du16, detail::MaxN(a, 0)); |
| const svuint16_t clamped_b = BitCast(du16, detail::MaxN(b, 0)); |
| const svuint8_t a8 = BitCast(d8, detail::SaturateU<uint8_t>(clamped_a)); |
| const svuint8_t b8 = BitCast(d8, detail::SaturateU<uint8_t>(clamped_b)); |
| return detail::InterleaveEven(a8, b8); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint8_t ReorderDemote2To(Simd<uint8_t, N, kPow2> d8, svuint16_t a, |
| svuint16_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d8; |
| const svuint8_t a_in_even = svqxtnb_u16(a); |
| return svqxtnt_u16(a_in_even, b); |
| #else |
| const svuint8_t a8 = BitCast(d8, detail::SaturateU<uint8_t>(a)); |
| const svuint8_t b8 = BitCast(d8, detail::SaturateU<uint8_t>(b)); |
| return detail::InterleaveEven(a8, b8); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svint32_t ReorderDemote2To(Simd<int32_t, N, kPow2> d32, svint64_t a, |
| svint64_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d32; |
| const svint32_t a_in_even = svqxtnb_s64(a); |
| return svqxtnt_s64(a_in_even, b); |
| #else |
| const svint32_t a32 = BitCast(d32, detail::SaturateI<int32_t>(a)); |
| const svint32_t b32 = BitCast(d32, detail::SaturateI<int32_t>(b)); |
| return detail::InterleaveEven(a32, b32); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint32_t ReorderDemote2To(Simd<uint32_t, N, kPow2> d32, svint64_t a, |
| svint64_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d32; |
| const svuint32_t a_in_even = svqxtunb_s64(a); |
| return svqxtunt_s64(a_in_even, b); |
| #else |
| const Repartition<uint64_t, decltype(d32)> du64; |
| const svuint64_t clamped_a = BitCast(du64, detail::MaxN(a, 0)); |
| const svuint64_t clamped_b = BitCast(du64, detail::MaxN(b, 0)); |
| const svuint32_t a32 = BitCast(d32, detail::SaturateU<uint32_t>(clamped_a)); |
| const svuint32_t b32 = BitCast(d32, detail::SaturateU<uint32_t>(clamped_b)); |
| return detail::InterleaveEven(a32, b32); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint32_t ReorderDemote2To(Simd<uint32_t, N, kPow2> d32, svuint64_t a, |
| svuint64_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d32; |
| const svuint32_t a_in_even = svqxtnb_u64(a); |
| return svqxtnt_u64(a_in_even, b); |
| #else |
| const svuint32_t a32 = BitCast(d32, detail::SaturateU<uint32_t>(a)); |
| const svuint32_t b32 = BitCast(d32, detail::SaturateU<uint32_t>(b)); |
| return detail::InterleaveEven(a32, b32); |
| #endif |
| } |
| |
| template <class D, class V, HWY_IF_SIGNED_D(D), HWY_IF_UNSIGNED_V(V), |
| HWY_IF_T_SIZE_D(D, sizeof(TFromV<V>) / 2)> |
| HWY_API VFromD<D> ReorderDemote2To(D dn, V a, V b) { |
| const auto clamped_a = BitCast(dn, detail::SaturateU<TFromD<D>>(a)); |
| const auto clamped_b = BitCast(dn, detail::SaturateU<TFromD<D>>(b)); |
| return detail::InterleaveEven(clamped_a, clamped_b); |
| } |
| |
| template <class D, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD<D>), |
| HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), |
| HWY_IF_T_SIZE_V(V, sizeof(TFromD<D>) * 2)> |
| HWY_API VFromD<D> OrderedDemote2To(D dn, V a, V b) { |
| const Half<decltype(dn)> dnh; |
| const auto demoted_a = DemoteTo(dnh, a); |
| const auto demoted_b = DemoteTo(dnh, b); |
| return Combine(dn, demoted_b, demoted_a); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API VBF16 OrderedDemote2To(Simd<bfloat16_t, N, kPow2> dbf16, svfloat32_t a, |
| svfloat32_t b) { |
| #if HWY_SVE_HAVE_F32_TO_BF16C |
| const VBF16 a_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), a); |
| const VBF16 b_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), b); |
| return ConcatEven(dbf16, b_in_even, a_in_even); |
| #else |
| const RebindToUnsigned<decltype(dbf16)> du16; |
| const svuint16_t a_in_odd = BitCast(du16, detail::RoundF32ForDemoteToBF16(a)); |
| const svuint16_t b_in_odd = BitCast(du16, detail::RoundF32ForDemoteToBF16(b)); |
| return BitCast(dbf16, ConcatOdd(du16, b_in_odd, a_in_odd)); // lower half |
| #endif |
| } |
| |
| // ------------------------------ I8/U8/I16/U16 Div |
| |
| template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), |
| HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2))> |
| HWY_API V Div(V a, V b) { |
| const DFromV<decltype(a)> d; |
| const Half<decltype(d)> dh; |
| const RepartitionToWide<decltype(d)> dw; |
| |
| const auto q_lo = |
| Div(PromoteTo(dw, LowerHalf(dh, a)), PromoteTo(dw, LowerHalf(dh, b))); |
| const auto q_hi = Div(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b)); |
| |
| return OrderedDemote2To(d, q_lo, q_hi); |
| } |
| |
| // ------------------------------ I8/U8/I16/U16 MaskedDivOr |
| template <class V, class M, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2)), |
| HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> |
| HWY_API V MaskedDivOr(V no, M m, V a, V b) { |
| return IfThenElse(m, Div(a, b), no); |
| } |
| |
| template <class V, class M, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2)), |
| HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> |
| HWY_API V MaskedDiv(M m, V a, V b) { |
| return IfThenElseZero(m, Div(a, b)); |
| } |
| |
| // ------------------------------ Mod (Div, NegMulAdd) |
| template <class V> |
| HWY_API V Mod(V a, V b) { |
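  // NegMulAdd(q, b, a) computes a - q * b. For example, Mod(7, 3) =
  // 7 - (7 / 3) * 3 = 1. Because integer Div truncates toward zero,
  // Mod(-7, 3) = -7 - (-2) * 3 = -1, matching C's % operator.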
| return NegMulAdd(Div(a, b), b, a); |
| } |
| |
| // ------------------------------ MaskedModOr (Mod) |
| template <class V, class M> |
| HWY_API V MaskedModOr(V no, M m, V a, V b) { |
| return IfThenElse(m, Mod(a, b), no); |
| } |
| |
| // ------------------------------ IfNegativeThenElse (BroadcastSignBit) |
| template <class V> |
| HWY_API V IfNegativeThenElse(V v, V yes, V no) { |
| static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float"); |
| return IfThenElse(IsNegative(v), yes, no); |
| } |
| // ------------------------------ IfNegativeThenNegOrUndefIfZero |
| |
| #ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG |
| #undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG |
| #else |
| #define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG |
| #endif |
| |
| #define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) { \ |
| return sv##OP##_##CHAR##BITS##_m(v, IsNegative(mask), v); \ |
| } |
| |
| HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg) |
| |
| #undef HWY_SVE_NEG_IF |
| |
| // ------------------------------ AverageRound (ShiftRight) |
| |
| #ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 |
| #undef HWY_NATIVE_AVERAGE_ROUND_UI32 |
| #else |
| #define HWY_NATIVE_AVERAGE_ROUND_UI32 |
| #endif |
| |
| #ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 |
| #undef HWY_NATIVE_AVERAGE_ROUND_UI64 |
| #else |
| #define HWY_NATIVE_AVERAGE_ROUND_UI64 |
| #endif |
| |
| #if HWY_SVE_HAVE_2 |
| HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) |
| #else |
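// Fallback via the identity a + b = (a | b) + (a & b): rounding up,
// (a + b + 1) >> 1 equals (a | b) - ((a ^ b) >> 1), which cannot overflow.
// For example, a = 5, b = 6: (5 | 6) - ((5 ^ 6) >> 1) = 7 - 1 = 6.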
| template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> |
| HWY_API V AverageRound(const V a, const V b) { |
| return Sub(Or(a, b), ShiftRight<1>(Xor(a, b))); |
| } |
| #endif // HWY_SVE_HAVE_2 |
| |
| // ------------------------------ LoadMaskBits (TestBit) |
| |
| // `p` points to at least 8 readable bytes, not all of which need be valid. |
| template <class D, HWY_IF_T_SIZE_D(D, 1)> |
| HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { |
| #if HWY_COMPILER_CLANG >= 1901 || HWY_COMPILER_GCC_ACTUAL >= 1200 |
| typedef svbool_t UnalignedSveMaskT |
| __attribute__((__aligned__(1), __may_alias__)); |
| (void)d; |
| return *reinterpret_cast<const UnalignedSveMaskT*>(bits); |
| #else |
| // TODO(janwas): with SVE2.1, load to vector, then PMOV |
| const RebindToUnsigned<D> du; |
| const svuint8_t iota = Iota(du, 0); |
| |
| // Load correct number of bytes (bits/8) with 7 zeros after each. |
| const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits)); |
| // Replicate bytes 8x such that each byte contains the bit that governs it. |
| const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota)); |
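  // For example, bits[0] = 0x05 replicates 0x05 into rep8 lanes 0..7, so the
  // TestBit below activates only lanes 0 and 2.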
| |
| const svuint8_t bit = |
| svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); |
| return TestBit(rep8, bit); |
| #endif |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 2)> |
| HWY_INLINE svbool_t LoadMaskBits(D /* tag */, |
| const uint8_t* HWY_RESTRICT bits) { |
| const RebindToUnsigned<D> du; |
| const Repartition<uint8_t, D> du8; |
| |
| // There may be up to 128 bits; avoid reading past the end. |
| const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits); |
| |
| // Replicate bytes 16x such that each lane contains the bit that governs it. |
| const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0))); |
| |
| const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); |
| return TestBit(BitCast(du, rep16), bit); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 4)> |
| HWY_INLINE svbool_t LoadMaskBits(D /* tag */, |
| const uint8_t* HWY_RESTRICT bits) { |
| const RebindToUnsigned<D> du; |
| const Repartition<uint8_t, D> du8; |
| |
| // Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable, |
| // so we can skip computing the actual length (Lanes(du)+7)/8. |
| const svuint8_t bytes = svld1(FirstN(du8, 8), bits); |
| |
| // Replicate bytes 32x such that each lane contains the bit that governs it. |
| const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0))); |
| |
| // 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 .. |
| const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7)); |
| |
| return TestBit(BitCast(du, rep32), bit); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 8)> |
| HWY_INLINE svbool_t LoadMaskBits(D /* tag */, |
| const uint8_t* HWY_RESTRICT bits) { |
| const RebindToUnsigned<D> du; |
| |
| // Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane. |
| // The "at least 8 byte" guarantee in quick_reference ensures this is safe. |
| uint32_t mask_bits; |
| CopyBytes<4>(bits, &mask_bits); // copy from bytes |
| const auto vbits = Set(du, mask_bits); |
| |
| // 2 ^ {0,1, .., 31}, will not have more lanes than that. |
| const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0)); |
| |
| return TestBit(vbits, bit); |
| } |
| |
| // ------------------------------ Dup128MaskFromMaskBits |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_V_SIZE_LE_D(D, 8)> |
| HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) { |
| const RebindToUnsigned<decltype(d)> du; |
| |
| constexpr size_t kN = MaxLanes(d); |
| if (kN < 8) mask_bits &= (1u << kN) - 1; |
| |
| // Replicate the lower 8 bits of mask_bits to each u8 lane |
  const svuint8_t bytes = Set(du, static_cast<uint8_t>(mask_bits));
| |
| const svuint8_t bit = |
| svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); |
| return TestBit(bytes, bit); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_V_SIZE_GT_D(D, 8)> |
| HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) { |
| const RebindToUnsigned<decltype(d)> du; |
| const Repartition<uint16_t, decltype(du)> du16; |
| |
| // Replicate the lower 16 bits of mask_bits to each u16 lane of a u16 vector, |
| // and then bitcast the replicated mask_bits to a u8 vector |
| const svuint8_t bytes = |
| BitCast(du, Set(du16, static_cast<uint16_t>(mask_bits))); |
| // Replicate bytes 8x such that each byte contains the bit that governs it. |
| const svuint8_t rep8 = svtbl_u8(bytes, ShiftRight<3>(Iota(du, 0))); |
| |
| const svuint8_t bit = |
| svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); |
| return TestBit(rep8, bit); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 2)> |
| HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) { |
| const RebindToUnsigned<decltype(d)> du; |
| const Repartition<uint8_t, decltype(d)> du8; |
| |
| constexpr size_t kN = MaxLanes(d); |
| if (kN < 8) mask_bits &= (1u << kN) - 1; |
| |
| // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits |
| const svuint8_t bytes = Set(du8, static_cast<uint8_t>(mask_bits)); |
| |
| const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); |
| return TestBit(BitCast(du, bytes), bit); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 4)> |
| HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) { |
| const RebindToUnsigned<decltype(d)> du; |
| const Repartition<uint8_t, decltype(d)> du8; |
| |
| constexpr size_t kN = MaxLanes(d); |
| if (kN < 4) mask_bits &= (1u << kN) - 1; |
| |
| // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits |
| const svuint8_t bytes = Set(du8, static_cast<uint8_t>(mask_bits)); |
| |
| const svuint32_t bit = svdupq_n_u32(1, 2, 4, 8); |
| return TestBit(BitCast(du, bytes), bit); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 8)> |
| HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) { |
| const RebindToUnsigned<decltype(d)> du; |
| const Repartition<uint8_t, decltype(d)> du8; |
| |
| if (MaxLanes(d) < 2) mask_bits &= 1u; |
| |
| // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits |
| const svuint8_t bytes = Set(du8, static_cast<uint8_t>(mask_bits)); |
| |
| const svuint64_t bit = svdupq_n_u64(1, 2); |
| return TestBit(BitCast(du, bytes), bit); |
| } |
| |
| // ------------------------------ StoreMaskBits (BitsFromMask) |
| |
| // `p` points to at least 8 writable bytes. |
| // TODO(janwas): with SVE2.1, use PMOV to store to vector, then StoreU |
| template <class D> |
| HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) { |
| #if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 |
  constexpr size_t kN = MaxLanes(d);
  const uint64_t bits64 = BitsFromMask(d, m);
  HWY_IF_CONSTEXPR(kN < 8) {
    // BitsFromMask guarantees upper bits are zero, hence no masking.
    bits[0] = static_cast<uint8_t>(bits64);
  }
  else {
    static_assert(kN % 8 == 0, "kN is pow2 >= 8, hence divisible");
    static_assert(HWY_IS_LITTLE_ENDIAN, "");
    hwy::CopyBytes<kN / 8>(&bits64, bits);
  }
  constexpr size_t num_bytes = hwy::DivCeil(kN, size_t{8});
| return num_bytes; |
| #else |
| svuint64_t bits_in_u64 = detail::BitsFromBool(detail::BoolFromMask<D>(m)); |
| |
| const size_t num_bits = Lanes(d); |
| const size_t num_bytes = hwy::DivCeil(num_bits, size_t{8}); |
| |
| // Truncate each u64 to 8 bits and store to u8. |
| svst1b_u64(FirstN(ScalableTag<uint64_t>(), num_bytes), bits, bits_in_u64); |
| |
| // Non-full byte, need to clear the undefined upper bits. Can happen for |
| // capped/fractional vectors or large T and small hardware vectors. |
| if (num_bits < 8) { |
| const int mask = static_cast<int>((1ull << num_bits) - 1); |
| bits[0] = static_cast<uint8_t>(bits[0] & mask); |
| } |
| // Else: we wrote full bytes because num_bits is a power of two >= 8. |
| |
| return num_bytes; |
| #endif // HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 |
| } |
| |
| // ------------------------------ CompressBits (LoadMaskBits) |
| template <class V, HWY_IF_NOT_T_SIZE_V(V, 1)> |
| HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { |
| return Compress(v, LoadMaskBits(DFromV<V>(), bits)); |
| } |
| |
| // ------------------------------ CompressBitsStore (LoadMaskBits) |
| template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)> |
| HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits, |
| D d, TFromD<D>* HWY_RESTRICT unaligned) { |
| return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); |
| } |
| |
| // ------------------------------ Expand (StoreMaskBits) |
| |
| #ifdef HWY_NATIVE_EXPAND |
| #undef HWY_NATIVE_EXPAND |
| #else |
| #define HWY_NATIVE_EXPAND |
| #endif |
| |
| namespace detail { |
| |
| HWY_INLINE svuint8_t IndicesForExpandFromBits(uint64_t mask_bits) { |
| const CappedTag<uint8_t, 8> du8; |
| alignas(16) static constexpr uint8_t table[8 * 256] = { |
| // PrintExpand8x8Tables |
| 128, 128, 128, 128, 128, 128, 128, 128, // |
| 0, 128, 128, 128, 128, 128, 128, 128, // |
| 128, 0, 128, 128, 128, 128, 128, 128, // |
| 0, 1, 128, 128, 128, 128, 128, 128, // |
| 128, 128, 0, 128, 128, 128, 128, 128, // |
| 0, 128, 1, 128, 128, 128, 128, 128, // |
| 128, 0, 1, 128, 128, 128, 128, 128, // |
| 0, 1, 2, 128, 128, 128, 128, 128, // |
| 128, 128, 128, 0, 128, 128, 128, 128, // |
| 0, 128, 128, 1, 128, 128, 128, 128, // |
| 128, 0, 128, 1, 128, 128, 128, 128, // |
| 0, 1, 128, 2, 128, 128, 128, 128, // |
| 128, 128, 0, 1, 128, 128, 128, 128, // |
| 0, 128, 1, 2, 128, 128, 128, 128, // |
| 128, 0, 1, 2, 128, 128, 128, 128, // |
| 0, 1, 2, 3, 128, 128, 128, 128, // |
| 128, 128, 128, 128, 0, 128, 128, 128, // |
| 0, 128, 128, 128, 1, 128, 128, 128, // |
| 128, 0, 128, 128, 1, 128, 128, 128, // |
| 0, 1, 128, 128, 2, 128, 128, 128, // |
| 128, 128, 0, 128, 1, 128, 128, 128, // |
| 0, 128, 1, 128, 2, 128, 128, 128, // |
| 128, 0, 1, 128, 2, 128, 128, 128, // |
| 0, 1, 2, 128, 3, 128, 128, 128, // |
| 128, 128, 128, 0, 1, 128, 128, 128, // |
| 0, 128, 128, 1, 2, 128, 128, 128, // |
| 128, 0, 128, 1, 2, 128, 128, 128, // |
| 0, 1, 128, 2, 3, 128, 128, 128, // |
| 128, 128, 0, 1, 2, 128, 128, 128, // |
| 0, 128, 1, 2, 3, 128, 128, 128, // |
| 128, 0, 1, 2, 3, 128, 128, 128, // |
| 0, 1, 2, 3, 4, 128, 128, 128, // |
| 128, 128, 128, 128, 128, 0, 128, 128, // |
| 0, 128, 128, 128, 128, 1, 128, 128, // |
| 128, 0, 128, 128, 128, 1, 128, 128, // |
| 0, 1, 128, 128, 128, 2, 128, 128, // |
| 128, 128, 0, 128, 128, 1, 128, 128, // |
| 0, 128, 1, 128, 128, 2, 128, 128, // |
| 128, 0, 1, 128, 128, 2, 128, 128, // |
| 0, 1, 2, 128, 128, 3, 128, 128, // |
| 128, 128, 128, 0, 128, 1, 128, 128, // |
| 0, 128, 128, 1, 128, 2, 128, 128, // |
| 128, 0, 128, 1, 128, 2, 128, 128, // |
| 0, 1, 128, 2, 128, 3, 128, 128, // |
| 128, 128, 0, 1, 128, 2, 128, 128, // |
| 0, 128, 1, 2, 128, 3, 128, 128, // |
| 128, 0, 1, 2, 128, 3, 128, 128, // |
| 0, 1, 2, 3, 128, 4, 128, 128, // |
| 128, 128, 128, 128, 0, 1, 128, 128, // |
| 0, 128, 128, 128, 1, 2, 128, 128, // |
| 128, 0, 128, 128, 1, 2, 128, 128, // |
| 0, 1, 128, 128, 2, 3, 128, 128, // |
| 128, 128, 0, 128, 1, 2, 128, 128, // |
| 0, 128, 1, 128, 2, 3, 128, 128, // |
| 128, 0, 1, 128, 2, 3, 128, 128, // |
| 0, 1, 2, 128, 3, 4, 128, 128, // |
| 128, 128, 128, 0, 1, 2, 128, 128, // |
| 0, 128, 128, 1, 2, 3, 128, 128, // |
| 128, 0, 128, 1, 2, 3, 128, 128, // |
| 0, 1, 128, 2, 3, 4, 128, 128, // |
| 128, 128, 0, 1, 2, 3, 128, 128, // |
| 0, 128, 1, 2, 3, 4, 128, 128, // |
| 128, 0, 1, 2, 3, 4, 128, 128, // |
| 0, 1, 2, 3, 4, 5, 128, 128, // |
| 128, 128, 128, 128, 128, 128, 0, 128, // |
| 0, 128, 128, 128, 128, 128, 1, 128, // |
| 128, 0, 128, 128, 128, 128, 1, 128, // |
| 0, 1, 128, 128, 128, 128, 2, 128, // |
| 128, 128, 0, 128, 128, 128, 1, 128, // |
| 0, 128, 1, 128, 128, 128, 2, 128, // |
| 128, 0, 1, 128, 128, 128, 2, 128, // |
| 0, 1, 2, 128, 128, 128, 3, 128, // |
| 128, 128, 128, 0, 128, 128, 1, 128, // |
| 0, 128, 128, 1, 128, 128, 2, 128, // |
| 128, 0, 128, 1, 128, 128, 2, 128, // |
| 0, 1, 128, 2, 128, 128, 3, 128, // |
| 128, 128, 0, 1, 128, 128, 2, 128, // |
| 0, 128, 1, 2, 128, 128, 3, 128, // |
| 128, 0, 1, 2, 128, 128, 3, 128, // |
| 0, 1, 2, 3, 128, 128, 4, 128, // |
| 128, 128, 128, 128, 0, 128, 1, 128, // |
| 0, 128, 128, 128, 1, 128, 2, 128, // |
| 128, 0, 128, 128, 1, 128, 2, 128, // |
| 0, 1, 128, 128, 2, 128, 3, 128, // |
| 128, 128, 0, 128, 1, 128, 2, 128, // |
| 0, 128, 1, 128, 2, 128, 3, 128, // |
| 128, 0, 1, 128, 2, 128, 3, 128, // |
| 0, 1, 2, 128, 3, 128, 4, 128, // |
| 128, 128, 128, 0, 1, 128, 2, 128, // |
| 0, 128, 128, 1, 2, 128, 3, 128, // |
| 128, 0, 128, 1, 2, 128, 3, 128, // |
| 0, 1, 128, 2, 3, 128, 4, 128, // |
| 128, 128, 0, 1, 2, 128, 3, 128, // |
| 0, 128, 1, 2, 3, 128, 4, 128, // |
| 128, 0, 1, 2, 3, 128, 4, 128, // |
| 0, 1, 2, 3, 4, 128, 5, 128, // |
| 128, 128, 128, 128, 128, 0, 1, 128, // |
| 0, 128, 128, 128, 128, 1, 2, 128, // |
| 128, 0, 128, 128, 128, 1, 2, 128, // |
| 0, 1, 128, 128, 128, 2, 3, 128, // |
| 128, 128, 0, 128, 128, 1, 2, 128, // |
| 0, 128, 1, 128, 128, 2, 3, 128, // |
| 128, 0, 1, 128, 128, 2, 3, 128, // |
| 0, 1, 2, 128, 128, 3, 4, 128, // |
| 128, 128, 128, 0, 128, 1, 2, 128, // |
| 0, 128, 128, 1, 128, 2, 3, 128, // |
| 128, 0, 128, 1, 128, 2, 3, 128, // |
| 0, 1, 128, 2, 128, 3, 4, 128, // |
| 128, 128, 0, 1, 128, 2, 3, 128, // |
| 0, 128, 1, 2, 128, 3, 4, 128, // |
| 128, 0, 1, 2, 128, 3, 4, 128, // |
| 0, 1, 2, 3, 128, 4, 5, 128, // |
| 128, 128, 128, 128, 0, 1, 2, 128, // |
| 0, 128, 128, 128, 1, 2, 3, 128, // |
| 128, 0, 128, 128, 1, 2, 3, 128, // |
| 0, 1, 128, 128, 2, 3, 4, 128, // |
| 128, 128, 0, 128, 1, 2, 3, 128, // |
| 0, 128, 1, 128, 2, 3, 4, 128, // |
| 128, 0, 1, 128, 2, 3, 4, 128, // |
| 0, 1, 2, 128, 3, 4, 5, 128, // |
| 128, 128, 128, 0, 1, 2, 3, 128, // |
| 0, 128, 128, 1, 2, 3, 4, 128, // |
| 128, 0, 128, 1, 2, 3, 4, 128, // |
| 0, 1, 128, 2, 3, 4, 5, 128, // |
| 128, 128, 0, 1, 2, 3, 4, 128, // |
| 0, 128, 1, 2, 3, 4, 5, 128, // |
| 128, 0, 1, 2, 3, 4, 5, 128, // |
| 0, 1, 2, 3, 4, 5, 6, 128, // |
| 128, 128, 128, 128, 128, 128, 128, 0, // |
| 0, 128, 128, 128, 128, 128, 128, 1, // |
| 128, 0, 128, 128, 128, 128, 128, 1, // |
| 0, 1, 128, 128, 128, 128, 128, 2, // |
| 128, 128, 0, 128, 128, 128, 128, 1, // |
| 0, 128, 1, 128, 128, 128, 128, 2, // |
| 128, 0, 1, 128, 128, 128, 128, 2, // |
| 0, 1, 2, 128, 128, 128, 128, 3, // |
| 128, 128, 128, 0, 128, 128, 128, 1, // |
| 0, 128, 128, 1, 128, 128, 128, 2, // |
| 128, 0, 128, 1, 128, 128, 128, 2, // |
| 0, 1, 128, 2, 128, 128, 128, 3, // |
| 128, 128, 0, 1, 128, 128, 128, 2, // |
| 0, 128, 1, 2, 128, 128, 128, 3, // |
| 128, 0, 1, 2, 128, 128, 128, 3, // |
| 0, 1, 2, 3, 128, 128, 128, 4, // |
| 128, 128, 128, 128, 0, 128, 128, 1, // |
| 0, 128, 128, 128, 1, 128, 128, 2, // |
| 128, 0, 128, 128, 1, 128, 128, 2, // |
| 0, 1, 128, 128, 2, 128, 128, 3, // |
| 128, 128, 0, 128, 1, 128, 128, 2, // |
| 0, 128, 1, 128, 2, 128, 128, 3, // |
| 128, 0, 1, 128, 2, 128, 128, 3, // |
| 0, 1, 2, 128, 3, 128, 128, 4, // |
| 128, 128, 128, 0, 1, 128, 128, 2, // |
| 0, 128, 128, 1, 2, 128, 128, 3, // |
| 128, 0, 128, 1, 2, 128, 128, 3, // |
| 0, 1, 128, 2, 3, 128, 128, 4, // |
| 128, 128, 0, 1, 2, 128, 128, 3, // |
| 0, 128, 1, 2, 3, 128, 128, 4, // |
| 128, 0, 1, 2, 3, 128, 128, 4, // |
| 0, 1, 2, 3, 4, 128, 128, 5, // |
| 128, 128, 128, 128, 128, 0, 128, 1, // |
| 0, 128, 128, 128, 128, 1, 128, 2, // |
| 128, 0, 128, 128, 128, 1, 128, 2, // |
| 0, 1, 128, 128, 128, 2, 128, 3, // |
| 128, 128, 0, 128, 128, 1, 128, 2, // |
| 0, 128, 1, 128, 128, 2, 128, 3, // |
| 128, 0, 1, 128, 128, 2, 128, 3, // |
| 0, 1, 2, 128, 128, 3, 128, 4, // |
| 128, 128, 128, 0, 128, 1, 128, 2, // |
| 0, 128, 128, 1, 128, 2, 128, 3, // |
| 128, 0, 128, 1, 128, 2, 128, 3, // |
| 0, 1, 128, 2, 128, 3, 128, 4, // |
| 128, 128, 0, 1, 128, 2, 128, 3, // |
| 0, 128, 1, 2, 128, 3, 128, 4, // |
| 128, 0, 1, 2, 128, 3, 128, 4, // |
| 0, 1, 2, 3, 128, 4, 128, 5, // |
| 128, 128, 128, 128, 0, 1, 128, 2, // |
| 0, 128, 128, 128, 1, 2, 128, 3, // |
| 128, 0, 128, 128, 1, 2, 128, 3, // |
| 0, 1, 128, 128, 2, 3, 128, 4, // |
| 128, 128, 0, 128, 1, 2, 128, 3, // |
| 0, 128, 1, 128, 2, 3, 128, 4, // |
| 128, 0, 1, 128, 2, 3, 128, 4, // |
| 0, 1, 2, 128, 3, 4, 128, 5, // |
| 128, 128, 128, 0, 1, 2, 128, 3, // |
| 0, 128, 128, 1, 2, 3, 128, 4, // |
| 128, 0, 128, 1, 2, 3, 128, 4, // |
| 0, 1, 128, 2, 3, 4, 128, 5, // |
| 128, 128, 0, 1, 2, 3, 128, 4, // |
| 0, 128, 1, 2, 3, 4, 128, 5, // |
| 128, 0, 1, 2, 3, 4, 128, 5, // |
| 0, 1, 2, 3, 4, 5, 128, 6, // |
| 128, 128, 128, 128, 128, 128, 0, 1, // |
| 0, 128, 128, 128, 128, 128, 1, 2, // |
| 128, 0, 128, 128, 128, 128, 1, 2, // |
| 0, 1, 128, 128, 128, 128, 2, 3, // |
| 128, 128, 0, 128, 128, 128, 1, 2, // |
| 0, 128, 1, 128, 128, 128, 2, 3, // |
| 128, 0, 1, 128, 128, 128, 2, 3, // |
| 0, 1, 2, 128, 128, 128, 3, 4, // |
| 128, 128, 128, 0, 128, 128, 1, 2, // |
| 0, 128, 128, 1, 128, 128, 2, 3, // |
| 128, 0, 128, 1, 128, 128, 2, 3, // |
| 0, 1, 128, 2, 128, 128, 3, 4, // |
| 128, 128, 0, 1, 128, 128, 2, 3, // |
| 0, 128, 1, 2, 128, 128, 3, 4, // |
| 128, 0, 1, 2, 128, 128, 3, 4, // |
| 0, 1, 2, 3, 128, 128, 4, 5, // |
| 128, 128, 128, 128, 0, 128, 1, 2, // |
| 0, 128, 128, 128, 1, 128, 2, 3, // |
| 128, 0, 128, 128, 1, 128, 2, 3, // |
| 0, 1, 128, 128, 2, 128, 3, 4, // |
| 128, 128, 0, 128, 1, 128, 2, 3, // |
| 0, 128, 1, 128, 2, 128, 3, 4, // |
| 128, 0, 1, 128, 2, 128, 3, 4, // |
| 0, 1, 2, 128, 3, 128, 4, 5, // |
| 128, 128, 128, 0, 1, 128, 2, 3, // |
| 0, 128, 128, 1, 2, 128, 3, 4, // |
| 128, 0, 128, 1, 2, 128, 3, 4, // |
| 0, 1, 128, 2, 3, 128, 4, 5, // |
| 128, 128, 0, 1, 2, 128, 3, 4, // |
| 0, 128, 1, 2, 3, 128, 4, 5, // |
| 128, 0, 1, 2, 3, 128, 4, 5, // |
| 0, 1, 2, 3, 4, 128, 5, 6, // |
| 128, 128, 128, 128, 128, 0, 1, 2, // |
| 0, 128, 128, 128, 128, 1, 2, 3, // |
| 128, 0, 128, 128, 128, 1, 2, 3, // |
| 0, 1, 128, 128, 128, 2, 3, 4, // |
| 128, 128, 0, 128, 128, 1, 2, 3, // |
| 0, 128, 1, 128, 128, 2, 3, 4, // |
| 128, 0, 1, 128, 128, 2, 3, 4, // |
| 0, 1, 2, 128, 128, 3, 4, 5, // |
| 128, 128, 128, 0, 128, 1, 2, 3, // |
| 0, 128, 128, 1, 128, 2, 3, 4, // |
| 128, 0, 128, 1, 128, 2, 3, 4, // |
| 0, 1, 128, 2, 128, 3, 4, 5, // |
| 128, 128, 0, 1, 128, 2, 3, 4, // |
| 0, 128, 1, 2, 128, 3, 4, 5, // |
| 128, 0, 1, 2, 128, 3, 4, 5, // |
| 0, 1, 2, 3, 128, 4, 5, 6, // |
| 128, 128, 128, 128, 0, 1, 2, 3, // |
| 0, 128, 128, 128, 1, 2, 3, 4, // |
| 128, 0, 128, 128, 1, 2, 3, 4, // |
| 0, 1, 128, 128, 2, 3, 4, 5, // |
| 128, 128, 0, 128, 1, 2, 3, 4, // |
| 0, 128, 1, 128, 2, 3, 4, 5, // |
| 128, 0, 1, 128, 2, 3, 4, 5, // |
| 0, 1, 2, 128, 3, 4, 5, 6, // |
| 128, 128, 128, 0, 1, 2, 3, 4, // |
| 0, 128, 128, 1, 2, 3, 4, 5, // |
| 128, 0, 128, 1, 2, 3, 4, 5, // |
| 0, 1, 128, 2, 3, 4, 5, 6, // |
| 128, 128, 0, 1, 2, 3, 4, 5, // |
| 0, 128, 1, 2, 3, 4, 5, 6, // |
| 128, 0, 1, 2, 3, 4, 5, 6, // |
| 0, 1, 2, 3, 4, 5, 6, 7}; |
| return Load(du8, table + mask_bits * 8); |
| } |
| |
| template <class D, HWY_IF_T_SIZE_D(D, 1)> |
| HWY_INLINE svuint8_t LaneIndicesFromByteIndices(D, svuint8_t idx) { |
| return idx; |
| } |
| template <class D, class DU = RebindToUnsigned<D>, HWY_IF_NOT_T_SIZE_D(D, 1)> |
| HWY_INLINE VFromD<DU> LaneIndicesFromByteIndices(D, svuint8_t idx) { |
| return PromoteTo(DU(), idx); |
| } |
| |
| // General case when we don't know the vector size, 8 elements at a time. |
| template <class V> |
| HWY_INLINE V ExpandLoop(V v, svbool_t mask) { |
| const DFromV<V> d; |
| using T = TFromV<V>; |
| uint8_t mask_bytes[256 / 8]; |
| StoreMaskBits(d, mask, mask_bytes); |
| |
  // ShiftLeftLanes is expensive, so we're probably better off storing to
  // memory and loading the final result.
| alignas(16) T out[2 * MaxLanes(d)]; |
| |
| svbool_t next = svpfalse_b(); |
| size_t input_consumed = 0; |
| const V iota = Iota(d, 0); |
| for (size_t i = 0; i < Lanes(d); i += 8) { |
| uint64_t mask_bits = mask_bytes[i / 8]; |
| |
| // We want to skip past the v lanes already consumed. There is no |
| // instruction for variable-shift-reg, but we can splice. |
| const V vH = detail::Splice(v, v, next); |
| input_consumed += PopCount(mask_bits); |
| next = detail::GeN(iota, ConvertScalarTo<T>(input_consumed)); |
| |
| const auto idx = detail::LaneIndicesFromByteIndices( |
| d, detail::IndicesForExpandFromBits(mask_bits)); |
| const V expand = TableLookupLanes(vH, idx); |
| StoreU(expand, d, out + i); |
| } |
| return LoadU(d, out); |
| } |
| |
| } // namespace detail |
| |
| template <class V, HWY_IF_T_SIZE_V(V, 1)> |
| HWY_API V Expand(V v, svbool_t mask) { |
| #if HWY_TARGET == HWY_SVE2_128 || HWY_IDE |
| const DFromV<V> d; |
| uint8_t mask_bytes[256 / 8]; |
| StoreMaskBits(d, mask, mask_bytes); |
| const uint64_t maskL = mask_bytes[0]; |
| const uint64_t maskH = mask_bytes[1]; |
| |
| // We want to skip past the v bytes already consumed by expandL. There is no |
| // instruction for shift-reg by variable bytes, but we can splice. Instead of |
| // GeN, Not(FirstN()) would also work. |
| using T = TFromV<V>; |
| const T countL = static_cast<T>(PopCount(maskL)); |
| const V vH = detail::Splice(v, v, detail::GeN(Iota(d, 0), countL)); |
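  // For example, if maskL has three set bits, vH is v rotated left by three
  // lanes, so lane 0 holds input lane 3, the first not yet consumed.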
| |
| const svuint8_t idxL = detail::IndicesForExpandFromBits(maskL); |
| const svuint8_t idxH = detail::IndicesForExpandFromBits(maskH); |
| return Combine(d, TableLookupLanes(vH, idxH), TableLookupLanes(v, idxL)); |
| #else |
| return detail::ExpandLoop(v, mask); |
| #endif |
| } |
| |
| template <class V, HWY_IF_T_SIZE_V(V, 2)> |
| HWY_API V Expand(V v, svbool_t mask) { |
| #if HWY_TARGET == HWY_SVE2_128 || HWY_IDE // 16x8 |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du16; |
| const Rebind<uint8_t, decltype(d)> du8; |
| // Convert mask into bitfield via horizontal sum (faster than ORV) of 8 bits. |
| // Pre-multiply by N so we can use it as an offset for Load. |
| const svuint16_t bits = Shl(Set(du16, 1), Iota(du16, 3)); |
| const size_t offset = detail::SumOfLanesM(mask, bits); |
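  // For example, active lanes {0, 2} contribute bits 8 and 32, so
  // offset = 40 = 5 * 8, i.e. row 5 of the table below.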
| |
| // Storing as 8-bit reduces table size from 4 KiB to 2 KiB. We cannot apply |
| // the nibble trick used below because not all indices fit within one lane. |
| alignas(16) static constexpr uint8_t table[8 * 256] = { |
| // PrintExpand16x8LaneTables |
| 255, 255, 255, 255, 255, 255, 255, 255, // |
| 0, 255, 255, 255, 255, 255, 255, 255, // |
| 255, 0, 255, 255, 255, 255, 255, 255, // |
| 0, 1, 255, 255, 255, 255, 255, 255, // |
| 255, 255, 0, 255, 255, 255, 255, 255, // |
| 0, 255, 1, 255, 255, 255, 255, 255, // |
| 255, 0, 1, 255, 255, 255, 255, 255, // |
| 0, 1, 2, 255, 255, 255, 255, 255, // |
| 255, 255, 255, 0, 255, 255, 255, 255, // |
| 0, 255, 255, 1, 255, 255, 255, 255, // |
| 255, 0, 255, 1, 255, 255, 255, 255, // |
| 0, 1, 255, 2, 255, 255, 255, 255, // |
| 255, 255, 0, 1, 255, 255, 255, 255, // |
| 0, 255, 1, 2, 255, 255, 255, 255, // |
| 255, 0, 1, 2, 255, 255, 255, 255, // |
| 0, 1, 2, 3, 255, 255, 255, 255, // |
| 255, 255, 255, 255, 0, 255, 255, 255, // |
| 0, 255, 255, 255, 1, 255, 255, 255, // |
| 255, 0, 255, 255, 1, 255, 255, 255, // |
| 0, 1, 255, 255, 2, 255, 255, 255, // |
| 255, 255, 0, 255, 1, 255, 255, 255, // |
| 0, 255, 1, 255, 2, 255, 255, 255, // |
| 255, 0, 1, 255, 2, 255, 255, 255, // |
| 0, 1, 2, 255, 3, 255, 255, 255, // |
| 255, 255, 255, 0, 1, 255, 255, 255, // |
| 0, 255, 255, 1, 2, 255, 255, 255, // |
| 255, 0, 255, 1, 2, 255, 255, 255, // |
| 0, 1, 255, 2, 3, 255, 255, 255, // |
| 255, 255, 0, 1, 2, 255, 255, 255, // |
| 0, 255, 1, 2, 3, 255, 255, 255, // |
| 255, 0, 1, 2, 3, 255, 255, 255, // |
| 0, 1, 2, 3, 4, 255, 255, 255, // |
| 255, 255, 255, 255, 255, 0, 255, 255, // |
| 0, 255, 255, 255, 255, 1, 255, 255, // |
| 255, 0, 255, 255, 255, 1, 255, 255, // |
| 0, 1, 255, 255, 255, 2, 255, 255, // |
| 255, 255, 0, 255, 255, 1, 255, 255, // |
| 0, 255, 1, 255, 255, 2, 255, 255, // |
| 255, 0, 1, 255, 255, 2, 255, 255, // |
| 0, 1, 2, 255, 255, 3, 255, 255, // |
| 255, 255, 255, 0, 255, 1, 255, 255, // |
| 0, 255, 255, 1, 255, 2, 255, 255, // |
| 255, 0, 255, 1, 255, 2, 255, 255, // |
| 0, 1, 255, 2, 255, 3, 255, 255, // |
| 255, 255, 0, 1, 255, 2, 255, 255, // |
| 0, 255, 1, 2, 255, 3, 255, 255, // |
| 255, 0, 1, 2, 255, 3, 255, 255, // |
| 0, 1, 2, 3, 255, 4, 255, 255, // |
| 255, 255, 255, 255, 0, 1, 255, 255, // |
| 0, 255, 255, 255, 1, 2, 255, 255, // |
| 255, 0, 255, 255, 1, 2, 255, 255, // |
| 0, 1, 255, 255, 2, 3, 255, 255, // |
| 255, 255, 0, 255, 1, 2, 255, 255, // |
| 0, 255, 1, 255, 2, 3, 255, 255, // |
| 255, 0, 1, 255, 2, 3, 255, 255, // |
| 0, 1, 2, 255, 3, 4, 255, 255, // |
| 255, 255, 255, 0, 1, 2, 255, 255, // |
| 0, 255, 255, 1, 2, 3, 255, 255, // |
| 255, 0, 255, 1, 2, 3, 255, 255, // |
| 0, 1, 255, 2, 3, 4, 255, 255, // |
| 255, 255, 0, 1, 2, 3, 255, 255, // |
| 0, 255, 1, 2, 3, 4, 255, 255, // |
| 255, 0, 1, 2, 3, 4, 255, 255, // |
| 0, 1, 2, 3, 4, 5, 255, 255, // |
| 255, 255, 255, 255, 255, 255, 0, 255, // |
| 0, 255, 255, 255, 255, 255, 1, 255, // |
| 255, 0, 255, 255, 255, 255, 1, 255, // |
| 0, 1, 255, 255, 255, 255, 2, 255, // |
| 255, 255, 0, 255, 255, 255, 1, 255, // |
| 0, 255, 1, 255, 255, 255, 2, 255, // |
| 255, 0, 1, 255, 255, 255, 2, 255, // |
| 0, 1, 2, 255, 255, 255, 3, 255, // |
| 255, 255, 255, 0, 255, 255, 1, 255, // |
| 0, 255, 255, 1, 255, 255, 2, 255, // |
| 255, 0, 255, 1, 255, 255, 2, 255, // |
| 0, 1, 255, 2, 255, 255, 3, 255, // |
| 255, 255, 0, 1, 255, 255, 2, 255, // |
| 0, 255, 1, 2, 255, 255, 3, 255, // |
| 255, 0, 1, 2, 255, 255, 3, 255, // |
| 0, 1, 2, 3, 255, 255, 4, 255, // |
| 255, 255, 255, 255, 0, 255, 1, 255, // |
| 0, 255, 255, 255, 1, 255, 2, 255, // |
| 255, 0, 255, 255, 1, 255, 2, 255, // |
| 0, 1, 255, 255, 2, 255, 3, 255, // |
| 255, 255, 0, 255, 1, 255, 2, 255, // |
| 0, 255, 1, 255, 2, 255, 3, 255, // |
| 255, 0, 1, 255, 2, 255, 3, 255, // |
| 0, 1, 2, 255, 3, 255, 4, 255, // |
| 255, 255, 255, 0, 1, 255, 2, 255, // |
| 0, 255, 255, 1, 2, 255, 3, 255, // |
| 255, 0, 255, 1, 2, 255, 3, 255, // |
| 0, 1, 255, 2, 3, 255, 4, 255, // |
| 255, 255, 0, 1, 2, 255, 3, 255, // |
| 0, 255, 1, 2, 3, 255, 4, 255, // |
| 255, 0, 1, 2, 3, 255, 4, 255, // |
| 0, 1, 2, 3, 4, 255, 5, 255, // |
| 255, 255, 255, 255, 255, 0, 1, 255, // |
| 0, 255, 255, 255, 255, 1, 2, 255, // |
| 255, 0, 255, 255, 255, 1, 2, 255, // |
| 0, 1, 255, 255, 255, 2, 3, 255, // |
| 255, 255, 0, 255, 255, 1, 2, 255, // |
| 0, 255, 1, 255, 255, 2, 3, 255, // |
| 255, 0, 1, 255, 255, 2, 3, 255, // |
| 0, 1, 2, 255, 255, 3, 4, 255, // |
| 255, 255, 255, 0, 255, 1, 2, 255, // |
| 0, 255, 255, 1, 255, 2, 3, 255, // |
| 255, 0, 255, 1, 255, 2, 3, 255, // |
| 0, 1, 255, 2, 255, 3, 4, 255, // |
| 255, 255, 0, 1, 255, 2, 3, 255, // |
| 0, 255, 1, 2, 255, 3, 4, 255, // |
| 255, 0, 1, 2, 255, 3, 4, 255, // |
| 0, 1, 2, 3, 255, 4, 5, 255, // |
| 255, 255, 255, 255, 0, 1, 2, 255, // |
| 0, 255, 255, 255, 1, 2, 3, 255, // |
| 255, 0, 255, 255, 1, 2, 3, 255, // |
| 0, 1, 255, 255, 2, 3, 4, 255, // |
| 255, 255, 0, 255, 1, 2, 3, 255, // |
| 0, 255, 1, 255, 2, 3, 4, 255, // |
| 255, 0, 1, 255, 2, 3, 4, 255, // |
| 0, 1, 2, 255, 3, 4, 5, 255, // |
| 255, 255, 255, 0, 1, 2, 3, 255, // |
| 0, 255, 255, 1, 2, 3, 4, 255, // |
| 255, 0, 255, 1, 2, 3, 4, 255, // |
| 0, 1, 255, 2, 3, 4, 5, 255, // |
| 255, 255, 0, 1, 2, 3, 4, 255, // |
| 0, 255, 1, 2, 3, 4, 5, 255, // |
| 255, 0, 1, 2, 3, 4, 5, 255, // |
| 0, 1, 2, 3, 4, 5, 6, 255, // |
| 255, 255, 255, 255, 255, 255, 255, 0, // |
| 0, 255, 255, 255, 255, 255, 255, 1, // |
| 255, 0, 255, 255, 255, 255, 255, 1, // |
| 0, 1, 255, 255, 255, 255, 255, 2, // |
| 255, 255, 0, 255, 255, 255, 255, 1, // |
| 0, 255, 1, 255, 255, 255, 255, 2, // |
| 255, 0, 1, 255, 255, 255, 255, 2, // |
| 0, 1, 2, 255, 255, 255, 255, 3, // |
| 255, 255, 255, 0, 255, 255, 255, 1, // |
| 0, 255, 255, 1, 255, 255, 255, 2, // |
| 255, 0, 255, 1, 255, 255, 255, 2, // |
| 0, 1, 255, 2, 255, 255, 255, 3, // |
| 255, 255, 0, 1, 255, 255, 255, 2, // |
| 0, 255, 1, 2, 255, 255, 255, 3, // |
| 255, 0, 1, 2, 255, 255, 255, 3, // |
| 0, 1, 2, 3, 255, 255, 255, 4, // |
| 255, 255, 255, 255, 0, 255, 255, 1, // |
| 0, 255, 255, 255, 1, 255, 255, 2, // |
| 255, 0, 255, 255, 1, 255, 255, 2, // |
| 0, 1, 255, 255, 2, 255, 255, 3, // |
| 255, 255, 0, 255, 1, 255, 255, 2, // |
| 0, 255, 1, 255, 2, 255, 255, 3, // |
| 255, 0, 1, 255, 2, 255, 255, 3, // |
| 0, 1, 2, 255, 3, 255, 255, 4, // |
| 255, 255, 255, 0, 1, 255, 255, 2, // |
| 0, 255, 255, 1, 2, 255, 255, 3, // |
| 255, 0, 255, 1, 2, 255, 255, 3, // |
| 0, 1, 255, 2, 3, 255, 255, 4, // |
| 255, 255, 0, 1, 2, 255, 255, 3, // |
| 0, 255, 1, 2, 3, 255, 255, 4, // |
| 255, 0, 1, 2, 3, 255, 255, 4, // |
| 0, 1, 2, 3, 4, 255, 255, 5, // |
| 255, 255, 255, 255, 255, 0, 255, 1, // |
| 0, 255, 255, 255, 255, 1, 255, 2, // |
| 255, 0, 255, 255, 255, 1, 255, 2, // |
| 0, 1, 255, 255, 255, 2, 255, 3, // |
| 255, 255, 0, 255, 255, 1, 255, 2, // |
| 0, 255, 1, 255, 255, 2, 255, 3, // |
| 255, 0, 1, 255, 255, 2, 255, 3, // |
| 0, 1, 2, 255, 255, 3, 255, 4, // |
| 255, 255, 255, 0, 255, 1, 255, 2, // |
| 0, 255, 255, 1, 255, 2, 255, 3, // |
| 255, 0, 255, 1, 255, 2, 255, 3, // |
| 0, 1, 255, 2, 255, 3, 255, 4, // |
| 255, 255, 0, 1, 255, 2, 255, 3, // |
| 0, 255, 1, 2, 255, 3, 255, 4, // |
| 255, 0, 1, 2, 255, 3, 255, 4, // |
| 0, 1, 2, 3, 255, 4, 255, 5, // |
| 255, 255, 255, 255, 0, 1, 255, 2, // |
| 0, 255, 255, 255, 1, 2, 255, 3, // |
| 255, 0, 255, 255, 1, 2, 255, 3, // |
| 0, 1, 255, 255, 2, 3, 255, 4, // |
| 255, 255, 0, 255, 1, 2, 255, 3, // |
| 0, 255, 1, 255, 2, 3, 255, 4, // |
| 255, 0, 1, 255, 2, 3, 255, 4, // |
| 0, 1, 2, 255, 3, 4, 255, 5, // |
| 255, 255, 255, 0, 1, 2, 255, 3, // |
| 0, 255, 255, 1, 2, 3, 255, 4, // |
| 255, 0, 255, 1, 2, 3, 255, 4, // |
| 0, 1, 255, 2, 3, 4, 255, 5, // |
| 255, 255, 0, 1, 2, 3, 255, 4, // |
| 0, 255, 1, 2, 3, 4, 255, 5, // |
| 255, 0, 1, 2, 3, 4, 255, 5, // |
| 0, 1, 2, 3, 4, 5, 255, 6, // |
| 255, 255, 255, 255, 255, 255, 0, 1, // |
| 0, 255, 255, 255, 255, 255, 1, 2, // |
| 255, 0, 255, 255, 255, 255, 1, 2, // |
| 0, 1, 255, 255, 255, 255, 2, 3, // |
| 255, 255, 0, 255, 255, 255, 1, 2, // |
| 0, 255, 1, 255, 255, 255, 2, 3, // |
| 255, 0, 1, 255, 255, 255, 2, 3, // |
| 0, 1, 2, 255, 255, 255, 3, 4, // |
| 255, 255, 255, 0, 255, 255, 1, 2, // |
| 0, 255, 255, 1, 255, 255, 2, 3, // |
| 255, 0, 255, 1, 255, 255, 2, 3, // |
| 0, 1, 255, 2, 255, 255, 3, 4, // |
| 255, 255, 0, 1, 255, 255, 2, 3, // |
| 0, 255, 1, 2, 255, 255, 3, 4, // |
| 255, 0, 1, 2, 255, 255, 3, 4, // |
| 0, 1, 2, 3, 255, 255, 4, 5, // |
| 255, 255, 255, 255, 0, 255, 1, 2, // |
| 0, 255, 255, 255, 1, 255, 2, 3, // |
| 255, 0, 255, 255, 1, 255, 2, 3, // |
| 0, 1, 255, 255, 2, 255, 3, 4, // |
| 255, 255, 0, 255, 1, 255, 2, 3, // |
| 0, 255, 1, 255, 2, 255, 3, 4, // |
| 255, 0, 1, 255, 2, 255, 3, 4, // |
| 0, 1, 2, 255, 3, 255, 4, 5, // |
| 255, 255, 255, 0, 1, 255, 2, 3, // |
| 0, 255, 255, 1, 2, 255, 3, 4, // |
| 255, 0, 255, 1, 2, 255, 3, 4, // |
| 0, 1, 255, 2, 3, 255, 4, 5, // |
| 255, 255, 0, 1, 2, 255, 3, 4, // |
| 0, 255, 1, 2, 3, 255, 4, 5, // |
| 255, 0, 1, 2, 3, 255, 4, 5, // |
| 0, 1, 2, 3, 4, 255, 5, 6, // |
| 255, 255, 255, 255, 255, 0, 1, 2, // |
| 0, 255, 255, 255, 255, 1, 2, 3, // |
| 255, 0, 255, 255, 255, 1, 2, 3, // |
| 0, 1, 255, 255, 255, 2, 3, 4, // |
| 255, 255, 0, 255, 255, 1, 2, 3, // |
| 0, 255, 1, 255, 255, 2, 3, 4, // |
| 255, 0, 1, 255, 255, 2, 3, 4, // |
| 0, 1, 2, 255, 255, 3, 4, 5, // |
| 255, 255, 255, 0, 255, 1, 2, 3, // |
| 0, 255, 255, 1, 255, 2, 3, 4, // |
| 255, 0, 255, 1, 255, 2, 3, 4, // |
| 0, 1, 255, 2, 255, 3, 4, 5, // |
| 255, 255, 0, 1, 255, 2, 3, 4, // |
| 0, 255, 1, 2, 255, 3, 4, 5, // |
| 255, 0, 1, 2, 255, 3, 4, 5, // |
| 0, 1, 2, 3, 255, 4, 5, 6, // |
| 255, 255, 255, 255, 0, 1, 2, 3, // |
| 0, 255, 255, 255, 1, 2, 3, 4, // |
| 255, 0, 255, 255, 1, 2, 3, 4, // |
| 0, 1, 255, 255, 2, 3, 4, 5, // |
| 255, 255, 0, 255, 1, 2, 3, 4, // |
| 0, 255, 1, 255, 2, 3, 4, 5, // |
| 255, 0, 1, 255, 2, 3, 4, 5, // |
| 0, 1, 2, 255, 3, 4, 5, 6, // |
| 255, 255, 255, 0, 1, 2, 3, 4, // |
| 0, 255, 255, 1, 2, 3, 4, 5, // |
| 255, 0, 255, 1, 2, 3, 4, 5, // |
| 0, 1, 255, 2, 3, 4, 5, 6, // |
| 255, 255, 0, 1, 2, 3, 4, 5, // |
| 0, 255, 1, 2, 3, 4, 5, 6, // |
| 255, 0, 1, 2, 3, 4, 5, 6, // |
| 0, 1, 2, 3, 4, 5, 6, 7}; |
| const svuint16_t indices = PromoteTo(du16, Load(du8, table + offset)); |
| return TableLookupLanes(v, indices); // already zeros mask=false lanes |
| #else |
| return detail::ExpandLoop(v, mask); |
| #endif |
| } |
| |
| template <class V, HWY_IF_T_SIZE_V(V, 4)> |
| HWY_API V Expand(V v, svbool_t mask) { |
| #if HWY_TARGET == HWY_SVE_256 || HWY_IDE // 32x8 |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du32; |
| // Convert mask into bitfield via horizontal sum (faster than ORV). |
| const svuint32_t bits = Shl(Set(du32, 1), Iota(du32, 0)); |
| const size_t code = detail::SumOfLanesM(mask, bits); |
| |
| alignas(16) constexpr uint32_t packed_array[256] = { |
| // PrintExpand32x8. |
| 0xffffffff, 0xfffffff0, 0xffffff0f, 0xffffff10, 0xfffff0ff, 0xfffff1f0, |
| 0xfffff10f, 0xfffff210, 0xffff0fff, 0xffff1ff0, 0xffff1f0f, 0xffff2f10, |
| 0xffff10ff, 0xffff21f0, 0xffff210f, 0xffff3210, 0xfff0ffff, 0xfff1fff0, |
| 0xfff1ff0f, 0xfff2ff10, 0xfff1f0ff, 0xfff2f1f0, 0xfff2f10f, 0xfff3f210, |
| 0xfff10fff, 0xfff21ff0, 0xfff21f0f, 0xfff32f10, 0xfff210ff, 0xfff321f0, |
| 0xfff3210f, 0xfff43210, 0xff0fffff, 0xff1ffff0, 0xff1fff0f, 0xff2fff10, |
| 0xff1ff0ff, 0xff2ff1f0, 0xff2ff10f, 0xff3ff210, 0xff1f0fff, 0xff2f1ff0, |
| 0xff2f1f0f, 0xff3f2f10, 0xff2f10ff, 0xff3f21f0, 0xff3f210f, 0xff4f3210, |
| 0xff10ffff, 0xff21fff0, 0xff21ff0f, 0xff32ff10, 0xff21f0ff, 0xff32f1f0, |
| 0xff32f10f, 0xff43f210, 0xff210fff, 0xff321ff0, 0xff321f0f, 0xff432f10, |
| 0xff3210ff, 0xff4321f0, 0xff43210f, 0xff543210, 0xf0ffffff, 0xf1fffff0, |
| 0xf1ffff0f, 0xf2ffff10, 0xf1fff0ff, 0xf2fff1f0, 0xf2fff10f, 0xf3fff210, |
| 0xf1ff0fff, 0xf2ff1ff0, 0xf2ff1f0f, 0xf3ff2f10, 0xf2ff10ff, 0xf3ff21f0, |
| 0xf3ff210f, 0xf4ff3210, 0xf1f0ffff, 0xf2f1fff0, 0xf2f1ff0f, 0xf3f2ff10, |
| 0xf2f1f0ff, 0xf3f2f1f0, 0xf3f2f10f, 0xf4f3f210, 0xf2f10fff, 0xf3f21ff0, |
| 0xf3f21f0f, 0xf4f32f10, 0xf3f210ff, 0xf4f321f0, 0xf4f3210f, 0xf5f43210, |
| 0xf10fffff, 0xf21ffff0, 0xf21fff0f, 0xf32fff10, 0xf21ff0ff, 0xf32ff1f0, |
| 0xf32ff10f, 0xf43ff210, 0xf21f0fff, 0xf32f1ff0, 0xf32f1f0f, 0xf43f2f10, |
| 0xf32f10ff, 0xf43f21f0, 0xf43f210f, 0xf54f3210, 0xf210ffff, 0xf321fff0, |
| 0xf321ff0f, 0xf432ff10, 0xf321f0ff, 0xf432f1f0, 0xf432f10f, 0xf543f210, |
| 0xf3210fff, 0xf4321ff0, 0xf4321f0f, 0xf5432f10, 0xf43210ff, 0xf54321f0, |
| 0xf543210f, 0xf6543210, 0x0fffffff, 0x1ffffff0, 0x1fffff0f, 0x2fffff10, |
| 0x1ffff0ff, 0x2ffff1f0, 0x2ffff10f, 0x3ffff210, 0x1fff0fff, 0x2fff1ff0, |
| 0x2fff1f0f, 0x3fff2f10, 0x2fff10ff, 0x3fff21f0, 0x3fff210f, 0x4fff3210, |
| 0x1ff0ffff, 0x2ff1fff0, 0x2ff1ff0f, 0x3ff2ff10, 0x2ff1f0ff, 0x3ff2f1f0, |
| 0x3ff2f10f, 0x4ff3f210, 0x2ff10fff, 0x3ff21ff0, 0x3ff21f0f, 0x4ff32f10, |
| 0x3ff210ff, 0x4ff321f0, 0x4ff3210f, 0x5ff43210, 0x1f0fffff, 0x2f1ffff0, |
| 0x2f1fff0f, 0x3f2fff10, 0x2f1ff0ff, 0x3f2ff1f0, 0x3f2ff10f, 0x4f3ff210, |
| 0x2f1f0fff, 0x3f2f1ff0, 0x3f2f1f0f, 0x4f3f2f10, 0x3f2f10ff, 0x4f3f21f0, |
| 0x4f3f210f, 0x5f4f3210, 0x2f10ffff, 0x3f21fff0, 0x3f21ff0f, 0x4f32ff10, |
| 0x3f21f0ff, 0x4f32f1f0, 0x4f32f10f, 0x5f43f210, 0x3f210fff, 0x4f321ff0, |
| 0x4f321f0f, 0x5f432f10, 0x4f3210ff, 0x5f4321f0, 0x5f43210f, 0x6f543210, |
| 0x10ffffff, 0x21fffff0, 0x21ffff0f, 0x32ffff10, 0x21fff0ff, 0x32fff1f0, |
| 0x32fff10f, 0x43fff210, 0x21ff0fff, 0x32ff1ff0, 0x32ff1f0f, 0x43ff2f10, |
| 0x32ff10ff, 0x43ff21f0, 0x43ff210f, 0x54ff3210, 0x21f0ffff, 0x32f1fff0, |
| 0x32f1ff0f, 0x43f2ff10, 0x32f1f0ff, 0x43f2f1f0, 0x43f2f10f, 0x54f3f210, |
| 0x32f10fff, 0x43f21ff0, 0x43f21f0f, 0x54f32f10, 0x43f210ff, 0x54f321f0, |
| 0x54f3210f, 0x65f43210, 0x210fffff, 0x321ffff0, 0x321fff0f, 0x432fff10, |
| 0x321ff0ff, 0x432ff1f0, 0x432ff10f, 0x543ff210, 0x321f0fff, 0x432f1ff0, |
| 0x432f1f0f, 0x543f2f10, 0x432f10ff, 0x543f21f0, 0x543f210f, 0x654f3210, |
| 0x3210ffff, 0x4321fff0, 0x4321ff0f, 0x5432ff10, 0x4321f0ff, 0x5432f1f0, |
| 0x5432f10f, 0x6543f210, 0x43210fff, 0x54321ff0, 0x54321f0f, 0x65432f10, |
| 0x543210ff, 0x654321f0, 0x6543210f, 0x76543210}; |
| |
| // For lane i, shift the i-th 4-bit index down and mask with 0xF because |
| // svtbl zeros outputs if the index is out of bounds. |
| const svuint32_t packed = Set(du32, packed_array[code]); |
| const svuint32_t indices = detail::AndN(Shr(packed, svindex_u32(0, 4)), 0xF); |
| return TableLookupLanes(v, indices); // already zeros mask=false lanes |
| #elif HWY_TARGET == HWY_SVE2_128 // 32x4 |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du32; |
| // Convert mask into bitfield via horizontal sum (faster than ORV). |
| const svuint32_t bits = Shl(Set(du32, 1), Iota(du32, 0)); |
| const size_t offset = detail::SumOfLanesM(mask, bits); |
| |
| alignas(16) constexpr uint32_t packed_array[16] = { |
| // PrintExpand64x4Nibble - same for 32x4. |
| 0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0, |
| 0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10, |
| 0x000010ff, 0x000021f0, 0x0000210f, 0x00003210}; |
| |
| // For lane i, shift the i-th 4-bit index down and mask with 0xF because |
| // svtbl zeros outputs if the index is out of bounds. |
| const svuint32_t packed = Set(du32, packed_array[offset]); |
| const svuint32_t indices = detail::AndN(Shr(packed, svindex_u32(0, 4)), 0xF); |
| return TableLookupLanes(v, indices); // already zeros mask=false lanes |
| #else |
| return detail::ExpandLoop(v, mask); |
| #endif |
| } |
| |
| template <class V, HWY_IF_T_SIZE_V(V, 8)> |
| HWY_API V Expand(V v, svbool_t mask) { |
| #if HWY_TARGET == HWY_SVE_256 || HWY_IDE // 64x4 |
| const DFromV<V> d; |
| const RebindToUnsigned<decltype(d)> du64; |
| |
| // Convert mask into bitfield via horizontal sum (faster than ORV) of masked |
| // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for |
| // SetTableIndices. |
| const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); |
| const size_t offset = detail::SumOfLanesM(mask, bits); |
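  // Example: mask = {true, false, false, true} yields 4 + 32 = 36, i.e. row
  // 9 of the table below ({0, 255, 255, 1}): lanes 0 and 3 receive v0 and v1
  // and the middle lanes are zeroed.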
| |
| alignas(16) static constexpr uint64_t table[4 * 16] = { |
| // PrintExpand64x4Tables - small enough to store uncompressed. |
| 255, 255, 255, 255, 0, 255, 255, 255, 255, 0, 255, 255, 0, 1, 255, 255, |
| 255, 255, 0, 255, 0, 255, 1, 255, 255, 0, 1, 255, 0, 1, 2, 255, |
| 255, 255, 255, 0, 0, 255, 255, 1, 255, 0, 255, 1, 0, 1, 255, 2, |
| 255, 255, 0, 1, 0, 255, 1, 2, 255, 0, 1, 2, 0, 1, 2, 3}; |
| // This already zeros mask=false lanes. |
| return TableLookupLanes(v, SetTableIndices(d, table + offset)); |
| #elif HWY_TARGET == HWY_SVE2_128 // 64x2 |
| // Same as Compress, just zero out the mask=false lanes. |
| return IfThenElseZero(mask, Compress(v, mask)); |
| #else |
| return detail::ExpandLoop(v, mask); |
| #endif |
| } |
| |
| // ------------------------------ LoadExpand |
| |
| template <class D> |
| HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d, |
| const TFromD<D>* HWY_RESTRICT unaligned) { |
| return Expand(LoadU(d, unaligned), mask); |
| } |
| |
| // ------------------------------ MulEven (InterleaveEven) |
| |
| #if HWY_SVE_HAVE_2 |
| namespace detail { |
| #define HWY_SVE_MUL_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_V(BASE, HALF) a, HWY_SVE_V(BASE, HALF) b) { \ |
| return sv##OP##_##CHAR##BITS(a, b); \ |
| } |
| |
| HWY_SVE_FOREACH_UI16(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) |
| HWY_SVE_FOREACH_UI16(HWY_SVE_MUL_EVEN, MulOddNative, mullt) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_MUL_EVEN, MulOddNative, mullt) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulOddNative, mullt) |
| #undef HWY_SVE_MUL_EVEN |
| } // namespace detail |
#endif  // HWY_SVE_HAVE_2
| |
| template <class V, class DW = RepartitionToWide<DFromV<V>>, |
| HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))> |
| HWY_API VFromD<DW> MulEven(const V a, const V b) { |
| #if HWY_SVE_HAVE_2 |
| return BitCast(DW(), detail::MulEvenNative(a, b)); |
| #else |
| const auto lo = Mul(a, b); |
| const auto hi = MulHigh(a, b); |
| return BitCast(DW(), detail::InterleaveEven(lo, hi)); |
| #endif |
| } |
| |
| template <class V, class DW = RepartitionToWide<DFromV<V>>, |
| HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))> |
| HWY_API VFromD<DW> MulOdd(const V a, const V b) { |
| #if HWY_SVE_HAVE_2 |
| return BitCast(DW(), detail::MulOddNative(a, b)); |
| #else |
| const auto lo = Mul(a, b); |
| const auto hi = MulHigh(a, b); |
| return BitCast(DW(), detail::InterleaveOdd(lo, hi)); |
| #endif |
| } |
| |
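// The following 64-bit overloads return the full 128-bit product of the even
// (MulEven) or odd (MulOdd) lanes as a lane pair: the low half in the even
// lane and the high half in the odd lane.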
| HWY_API svint64_t MulEven(const svint64_t a, const svint64_t b) { |
| const auto lo = Mul(a, b); |
| const auto hi = MulHigh(a, b); |
| return detail::InterleaveEven(lo, hi); |
| } |
| |
| HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) { |
| const auto lo = Mul(a, b); |
| const auto hi = MulHigh(a, b); |
| return detail::InterleaveEven(lo, hi); |
| } |
| |
| HWY_API svint64_t MulOdd(const svint64_t a, const svint64_t b) { |
| const auto lo = Mul(a, b); |
| const auto hi = MulHigh(a, b); |
| return detail::InterleaveOdd(lo, hi); |
| } |
| |
| HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) { |
| const auto lo = Mul(a, b); |
| const auto hi = MulHigh(a, b); |
| return detail::InterleaveOdd(lo, hi); |
| } |
| |
| // ------------------------------ PairwiseAdd/PairwiseSub |
| #if HWY_TARGET != HWY_SCALAR |
| #if HWY_SVE_HAVE_2 || HWY_IDE |
| |
| #ifdef HWY_NATIVE_PAIRWISE_ADD |
| #undef HWY_NATIVE_PAIRWISE_ADD |
| #else |
| #define HWY_NATIVE_PAIRWISE_ADD |
| #endif |
| |
| namespace detail { |
| #define HWY_SVE_SV_PAIRWISE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, HWY_SVE_V(BASE, BITS) a, \ |
| HWY_SVE_V(BASE, BITS) b) { \ |
| return sv##OP##_##CHAR##BITS##_m(HWY_SVE_PTRUE(BITS), a, b); \ |
| } |
| |
| HWY_SVE_FOREACH(HWY_SVE_SV_PAIRWISE_ADD, PairwiseAdd, addp) |
| #undef HWY_SVE_SV_PAIRWISE_ADD |
| } // namespace detail |
| |
| // Pairwise add returning interleaved output of a and b |
| template <class D, class V, HWY_IF_LANES_GT_D(D, 1)> |
| HWY_API V PairwiseAdd(D d, V a, V b) { |
| return detail::PairwiseAdd(d, a, b); |
| } |
| |
| #endif // HWY_SVE_HAVE_2 |
| #endif // HWY_TARGET != HWY_SCALAR |
| |
| // ------------------------------ WidenMulPairwiseAdd |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat32_t WidenMulPairwiseAdd(Simd<float, N, kPow2> df, VBF16 a, |
| VBF16 b) { |
| #if HWY_SVE_HAVE_F32_TO_BF16C |
| const svfloat32_t even = svbfmlalb_f32(Zero(df), a, b); |
| return svbfmlalt_f32(even, a, b); |
| #else |
| return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), |
| Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); |
#endif  // HWY_SVE_HAVE_F32_TO_BF16C
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svint32_t WidenMulPairwiseAdd(Simd<int32_t, N, kPow2> d32, svint16_t a, |
| svint16_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d32; |
| return svmlalt_s32(svmullb_s32(a, b), a, b); |
| #else |
| return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), |
| Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint32_t WidenMulPairwiseAdd(Simd<uint32_t, N, kPow2> d32, |
| svuint16_t a, svuint16_t b) { |
| #if HWY_SVE_HAVE_2 |
| (void)d32; |
| return svmlalt_u32(svmullb_u32(a, b), a, b); |
| #else |
| return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), |
| Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); |
| #endif |
| } |
| |
| // ------------------------------ SatWidenMulPairwiseAccumulate |
| #if HWY_SVE_HAVE_2 |
| #define HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME(HWY_SVE_D(BASE, BITS, N, kPow2) dw, HWY_SVE_V(BASE, HALF) a, \ |
| HWY_SVE_V(BASE, HALF) b, HWY_SVE_V(BASE, BITS) sum) { \ |
| auto product = svmlalt_##CHAR##BITS(svmullb_##CHAR##BITS(a, b), a, b); \ |
| const auto mul_overflow = IfThenElseZero( \ |
| Eq(product, Set(dw, LimitsMin<int##BITS##_t>())), Set(dw, -1)); \ |
| return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)), \ |
| Add(product, mul_overflow)); \ |
| } |
| HWY_SVE_FOREACH_UI16(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2, |
| SatWidenMulPairwiseAccumulate, _) |
| HWY_SVE_FOREACH_UI32(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2, |
| SatWidenMulPairwiseAccumulate, _) |
| HWY_SVE_FOREACH_UI64(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2, |
| SatWidenMulPairwiseAccumulate, _) |
| |
| #undef HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2 |
#endif  // HWY_SVE_HAVE_2
| |
| // ------------------------------ SatWidenMulAccumFixedPoint |
| |
| #if HWY_SVE_HAVE_2 |
| |
| #ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT |
| #undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT |
| #else |
| #define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT |
| #endif |
| |
| template <class DI32, HWY_IF_I32_D(DI32)> |
| HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 /*di32*/, |
| VFromD<Rebind<int16_t, DI32>> a, |
| VFromD<Rebind<int16_t, DI32>> b, |
| VFromD<DI32> sum) { |
| return svqdmlalb_s32(sum, detail::ZipLowerSame(a, a), |
| detail::ZipLowerSame(b, b)); |
| } |
| |
| #endif // HWY_SVE_HAVE_2 |
| |
| // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) |
| |
| #if HWY_SVE_HAVE_BF16_FEATURE |
| |
| // NOTE: we currently do not use SVE BFDOT for bf16 ReorderWidenMulAccumulate |
| // because, apparently unlike NEON, it uses round to odd unless the additional |
| // FEAT_EBF16 feature is available and enabled. |
| #ifdef HWY_NATIVE_MUL_EVEN_BF16 |
| #undef HWY_NATIVE_MUL_EVEN_BF16 |
| #else |
| #define HWY_NATIVE_MUL_EVEN_BF16 |
| #endif |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat32_t MulEvenAdd(Simd<float, N, kPow2> /* d */, VBF16 a, VBF16 b, |
| const svfloat32_t c) { |
| return svbfmlalb_f32(c, a, b); |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svfloat32_t MulOddAdd(Simd<float, N, kPow2> /* d */, VBF16 a, VBF16 b, |
| const svfloat32_t c) { |
| return svbfmlalt_f32(c, a, b); |
| } |
| |
| #endif // HWY_SVE_HAVE_BF16_FEATURE |
| |
| template <size_t N, int kPow2> |
| HWY_API svint32_t ReorderWidenMulAccumulate(Simd<int32_t, N, kPow2> d32, |
| svint16_t a, svint16_t b, |
| const svint32_t sum0, |
| svint32_t& sum1) { |
| #if HWY_SVE_HAVE_2 |
| (void)d32; |
| sum1 = svmlalt_s32(sum1, a, b); |
| return svmlalb_s32(sum0, a, b); |
| #else |
| // Lane order within sum0/1 is undefined, hence we can avoid the |
| // longer-latency lane-crossing PromoteTo by using PromoteEvenTo. |
| sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1); |
| return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0); |
| #endif |
| } |
| |
| template <size_t N, int kPow2> |
| HWY_API svuint32_t ReorderWidenMulAccumulate(Simd<uint32_t, N, kPow2> d32, |
| svuint16_t a, svuint16_t b, |
| const svuint32_t sum0, |
| svuint32_t& sum1) { |
| #if HWY_SVE_HAVE_2 |
| (void)d32; |
| sum1 = svmlalt_u32(sum1, a, b); |
| return svmlalb_u32(sum0, a, b); |
| #else |
| // Lane order within sum0/1 is undefined, hence we can avoid the |
| // longer-latency lane-crossing PromoteTo by using PromoteEvenTo. |
| sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1); |
| return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0); |
| #endif |
| } |
| |
| // ------------------------------ RearrangeToOddPlusEven |
| template <class VW> |
| HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { |
| // sum0 is the sum of bottom/even lanes and sum1 of top/odd lanes. |
| return Add(sum0, sum1); |
| } |
| |
| // ------------------------------ SumOfMulQuadAccumulate |
| |
| #ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE |
| #undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE |
| #else |
| #define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE |
| #endif |
| |
| template <class DI32, HWY_IF_I32_D(DI32)> |
| HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/, svint8_t a, |
| svint8_t b, svint32_t sum) { |
| return svdot_s32(sum, a, b); |
| } |
| |
| #ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE |
| #undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE |
| #else |
| #define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE |
| #endif |
| |
| template <class DU32, HWY_IF_U32_D(DU32)> |
| HWY_API VFromD<DU32> SumOfMulQuadAccumulate(DU32 /*du32*/, svuint8_t a, |
| svuint8_t b, svuint32_t sum) { |
| return svdot_u32(sum, a, b); |
| } |
| |
| #ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE |
| #undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE |
| #else |
| #define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE |
| #endif |
| |
| template <class DI32, HWY_IF_I32_D(DI32)> |
| HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u, |
| svint8_t b_i, svint32_t sum) { |
  // TODO: use svusdot_s32 on SVE targets that require support for both SVE2
  // and SVE I8MM.
| |
| const RebindToUnsigned<decltype(di32)> du32; |
| const Repartition<uint8_t, decltype(di32)> du8; |
| |
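  // Reinterpreting a negative b_i lane as unsigned adds 256 to its value, so
  // dot(a_u, b_u) overshoots dot(a_u, b_i) by 256 * sum(a_u where b_i < 0).
  // ShiftRight<7> of the unsigned lanes is 1 exactly for negative b_i, so
  // result_sum1 below equals that overshoot; subtracting it corrects the sum.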
| const auto b_u = BitCast(du8, b_i); |
| const auto result_sum0 = svdot_u32(BitCast(du32, sum), a_u, b_u); |
| const auto result_sum1 = |
| ShiftLeft<8>(svdot_u32(Zero(du32), a_u, ShiftRight<7>(b_u))); |
| |
| return BitCast(di32, Sub(result_sum0, result_sum1)); |
| } |
| |
| #ifdef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE |
| #undef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE |
| #else |
| #define HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE |
| #endif |
| |
| template <class DI64, HWY_IF_I64_D(DI64)> |
| HWY_API VFromD<DI64> SumOfMulQuadAccumulate(DI64 /*di64*/, svint16_t a, |
| svint16_t b, svint64_t sum) { |
| return svdot_s64(sum, a, b); |
| } |
| |
| #ifdef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE |
| #undef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE |
| #else |
| #define HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE |
| #endif |
| |
| template <class DU64, HWY_IF_U64_D(DU64)> |
| HWY_API VFromD<DU64> SumOfMulQuadAccumulate(DU64 /*du64*/, svuint16_t a, |
| svuint16_t b, svuint64_t sum) { |
| return svdot_u64(sum, a, b); |
| } |
| |
| // ------------------------------ MulComplex* / MaskedMulComplex* |
| |
| // Per-target flag to prevent generic_ops-inl.h from defining MulComplex*. |
| #ifdef HWY_NATIVE_CPLX |
| #undef HWY_NATIVE_CPLX |
| #else |
| #define HWY_NATIVE_CPLX |
| #endif |
| |
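// Complex numbers are stored as interleaved (real, imaginary) lane pairs;
// conjugation negates the imaginary (odd) lanes.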
| template <class V, HWY_IF_NOT_UNSIGNED(TFromV<V>)> |
| HWY_API V ComplexConj(V a) { |
| return OddEven(Neg(a), a); |
| } |
| |
| namespace detail { |
| #define HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, ROT) \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME##ROT(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ |
| HWY_SVE_V(BASE, BITS) c) { \ |
| return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b, c, ROT); \ |
| } \ |
| HWY_API HWY_SVE_V(BASE, BITS) \ |
| NAME##Z##ROT(svbool_t m, HWY_SVE_V(BASE, BITS) a, \ |
| HWY_SVE_V(BASE, BITS) b, HWY_SVE_V(BASE, BITS) c) { \ |
| return sv##OP##_##CHAR##BITS##_z(m, a, b, c, ROT); \ |
| } |
| |
| #define HWY_SVE_CPLX_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 0) \ |
| HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 90) \ |
| HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 180) \ |
| HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 270) |
| |
// Only SVE2 has complex multiply-add for integer types, and those intrinsics
// lack masked variants.
| HWY_SVE_FOREACH_F(HWY_SVE_CPLX_FMA, ComplexMulAdd, cmla) |
| #undef HWY_SVE_CPLX_FMA |
| #undef HWY_SVE_CPLX_FMA_ROT |
| } // namespace detail |
| |
| template <class V, class M, HWY_IF_FLOAT_V(V)> |
| HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) { |
| const V t = detail::ComplexMulAddZ0(mask, c, b, a); |
| return detail::ComplexMulAddZ270(mask, t, b, a); |
| } |
| |
| template <class V, class M, HWY_IF_FLOAT_V(V)> |
| HWY_API V MaskedMulComplexConj(M mask, V a, V b) { |
| return MaskedMulComplexConjAdd(mask, a, b, Zero(DFromV<V>())); |
| } |
| |
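// FCMLA with rotation 0 accumulates Re(a) * b, and with rotation 90 it
// accumulates (i * Im(a)) * b; the sum of the two is the full complex product
// a * b, so two chained ComplexMulAdd calls compute c + a * b.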
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V MulComplexAdd(V a, V b, V c) { |
| return detail::ComplexMulAdd90(detail::ComplexMulAdd0(c, a, b), a, b); |
| } |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V MulComplex(V a, V b) { |
| return MulComplexAdd(a, b, Zero(DFromV<V>())); |
| } |
| |
| template <class V, class M, HWY_IF_FLOAT_V(V)> |
| HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) { |
| return IfThenElse(mask, MulComplex(a, b), no); |
| } |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V MulComplexConjAdd(V a, V b, V c) { |
| return detail::ComplexMulAdd270(detail::ComplexMulAdd0(c, b, a), b, a); |
| } |
| |
| template <class V, HWY_IF_FLOAT_V(V)> |
| HWY_API V MulComplexConj(V a, V b) { |
| return MulComplexConjAdd(a, b, Zero(DFromV<V>())); |
| } |
| |
// TODO: SVE2 has integer intrinsics for these, but no masked variants.
| template <class V, HWY_IF_NOT_FLOAT_V(V)> |
| HWY_API V MulComplex(V a, V b) { |
| // a = u + iv, b = x + iy |
| const auto u = DupEven(a); |
| const auto v = DupOdd(a); |
| const auto x = DupEven(b); |
| const auto y = DupOdd(b); |
| |
| return OddEven(MulAdd(u, y, Mul(v, x)), Sub(Mul(u, x), Mul(v, y))); |
| } |
| |
| template <class V, HWY_IF_NOT_FLOAT_V(V)> |
| HWY_API V MulComplexConj(V a, V b) { |
| // a = u + iv, b = x + iy |
| const auto u = DupEven(a); |
| const auto v = DupOdd(a); |
| const auto x = DupEven(b); |
| const auto y = DupOdd(b); |
| |
| return OddEven(Sub(Mul(v, x), Mul(u, y)), MulAdd(u, x, Mul(v, y))); |
| } |
| |
| template <class V, HWY_IF_NOT_FLOAT_V(V)> |
| HWY_API V MulComplexAdd(V a, V b, V c) { |
| return Add(MulComplex(a, b), c); |
| } |
| |
| template <class V, HWY_IF_NOT_FLOAT_V(V)> |
| HWY_API V MulComplexConjAdd(V a, V b, V c) { |
| return Add(MulComplexConj(a, b), c); |
| } |
| |
| template <class V, class M, HWY_IF_NOT_FLOAT_V(V)> |
| HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) { |
| return IfThenElseZero(mask, MulComplexConjAdd(a, b, c)); |
| } |
| |
| template <class V, class M, HWY_IF_NOT_FLOAT_V(V)> |
| HWY_API V MaskedMulComplexConj(M mask, V a, V b) { |
| return IfThenElseZero(mask, MulComplexConj(a, b)); |
| } |
| |
| template <class V, class M, HWY_IF_NOT_FLOAT_V(V)> |
| HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) { |
| return IfThenElse(mask, MulComplex(a, b), no); |
| } |
| |
| // ------------------------------ AESRound / CLMul |
| |
| // Static dispatch with -march=armv8-a+sve2+aes, or dynamic dispatch WITHOUT a |
| // baseline, in which case we check for AES support at runtime. |
| #if defined(__ARM_FEATURE_SVE2_AES) || \ |
| (HWY_SVE_HAVE_2 && HWY_HAVE_RUNTIME_DISPATCH && HWY_BASELINE_SVE2 == 0) |
| |
| // Per-target flag to prevent generic_ops-inl.h from defining AESRound. |
| #ifdef HWY_NATIVE_AES |
| #undef HWY_NATIVE_AES |
| #else |
| #define HWY_NATIVE_AES |
| #endif |
| |
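// AESE XORs the state with its key operand and then applies SubBytes and
// ShiftRows; passing a zero key, applying MixColumns via AESMC, and XORing
// the round key afterwards thus performs one full AES round.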
| HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) { |
  // It is unclear whether AESE and AESMC fuse here as they do on NEON.
| return Xor(svaesmc_u8(svaese_u8(state, svdup_n_u8(0))), round_key); |
| } |
| |
| HWY_API svuint8_t AESLastRound(svuint8_t state, svuint8_t round_key) { |
| return Xor(svaese_u8(state, svdup_n_u8(0)), round_key); |
| } |
| |
| HWY_API svuint8_t AESInvMixColumns(svuint8_t state) { |
| return svaesimc_u8(state); |
| } |
| |
| HWY_API svuint8_t AESRoundInv(svuint8_t state, svuint8_t round_key) { |
| return Xor(svaesimc_u8(svaesd_u8(state, svdup_n_u8(0))), round_key); |
| } |
| |
| HWY_API svuint8_t AESLastRoundInv(svuint8_t state, svuint8_t round_key) { |
| return Xor(svaesd_u8(state, svdup_n_u8(0)), round_key); |
| } |
| |
| template <uint8_t kRcon> |
| HWY_API svuint8_t AESKeyGenAssist(svuint8_t v) { |
| alignas(16) static constexpr uint8_t kRconXorMask[16] = { |
| 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0}; |
| alignas(16) static constexpr uint8_t kRotWordShuffle[16] = { |
| 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12}; |
| const DFromV<decltype(v)> d; |
| const Repartition<uint32_t, decltype(d)> du32; |
| const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); |
| const auto sub_word_result = AESLastRound(w13, LoadDup128(d, kRconXorMask)); |
| return TableLookupBytes(sub_word_result, LoadDup128(d, kRotWordShuffle)); |
| } |
| |
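// Carry-less (polynomial) multiply: PMULLB multiplies the even (lower) and
// PMULLT the odd (upper) 64-bit lanes of each lane pair, producing a 128-bit
// product that fills the pair.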
| HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) { |
| return svpmullb_pair(a, b); |
| } |
| |
| HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) { |
| return svpmullt_pair(a, b); |
| } |
| |
#endif  // __ARM_FEATURE_SVE2_AES || (HWY_SVE_HAVE_2 && runtime dispatch)
| |
| // ------------------------------ Lt128 |
| |
| namespace detail { |
| #define HWY_SVE_DUP(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| template <size_t N, int kPow2> \ |
| HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, svbool_t m) { \ |
| return sv##OP##_b##BITS(m, m); \ |
| } |
| |
| HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupEvenB, trn1) // actually for bool |
| HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupOddB, trn2) // actually for bool |
| #undef HWY_SVE_DUP |
| |
| #if HWY_TARGET == HWY_SVE_256 || HWY_IDE |
| template <class D> |
| HWY_INLINE svuint64_t Lt128Vec(D d, const svuint64_t a, const svuint64_t b) { |
| static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64"); |
| const svbool_t eqHx = Eq(a, b); // only odd lanes used |
| // Convert to vector: more pipelines can execute vector TRN* instructions |
| // than the predicate version. |
| const svuint64_t ltHL = VecFromMask(d, Lt(a, b)); |
| // Move into upper lane: ltL if the upper half is equal, otherwise ltH. |
| // Requires an extra IfThenElse because INSR, EXT, TRN2 are unpredicated. |
| const svuint64_t ltHx = IfThenElse(eqHx, DupEven(ltHL), ltHL); |
| // Duplicate upper lane into lower. |
| return DupOdd(ltHx); |
| } |
#endif  // HWY_TARGET == HWY_SVE_256
| } // namespace detail |
| |
| template <class D> |
| HWY_INLINE svbool_t Lt128(D d, const svuint64_t a, const svuint64_t b) { |
| #if HWY_TARGET == HWY_SVE_256 |
| return MaskFromVec(detail::Lt128Vec(d, a, b)); |
| #else |
| static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64"); |
| const svbool_t eqHx = Eq(a, b); // only odd lanes used |
| const svbool_t ltHL = Lt(a, b); |
| // Move into upper lane: ltL if the upper half is equal, otherwise ltH. |
| const svbool_t ltHx = svsel_b(eqHx, detail::DupEvenB(d, ltHL), ltHL); |
| // Duplicate upper lane into lower. |
| return detail::DupOddB(d, ltHx); |
| #endif // HWY_TARGET != HWY_SVE_256 |
| } |
| |
| // ------------------------------ Lt128Upper |
| |
| template <class D> |
| HWY_INLINE svbool_t Lt128Upper(D d, svuint64_t a, svuint64_t b) { |
| static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64"); |
| const svbool_t ltHL = Lt(a, b); |
| return detail::DupOddB(d, ltHL); |
| } |
| |
| // ------------------------------ Eq128, Ne128 |
| |
| #if HWY_TARGET == HWY_SVE_256 || HWY_IDE |
| namespace detail { |
| |
| template <class D> |
| HWY_INLINE svuint64_t Eq128Vec(D d, const svuint64_t a, const svuint64_t b) { |
| static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64"); |
| // Convert to vector: more pipelines can execute vector TRN* instructions |
| // than the predicate version. |
| const svuint64_t eqHL = VecFromMask(d, Eq(a, b)); |
| // Duplicate upper and lower. |
| const svuint64_t eqHH = DupOdd(eqHL); |
| const svuint64_t eqLL = DupEven(eqHL); |
| return And(eqLL, eqHH); |
| } |
| |
| template <class D> |
| HWY_INLINE svuint64_t Ne128Vec(D d, const svuint64_t a, const svuint64_t b) { |
| static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64"); |
| // Convert to vector: more pipelines can execute vector TRN* instructions |
| // than the predicate version. |
| const svuint64_t neHL = VecFromMask(d, Ne(a, b)); |
| // Duplicate upper and lower. |
| const svuint64_t neHH = DupOdd(neHL); |
| const svuint64_t neLL = DupEven(neHL); |
| return Or(neLL, neHH); |
| } |
| |
| } // namespace detail |
#endif  // HWY_TARGET == HWY_SVE_256
| |
| template <class D> |
| HWY_INLINE svbool_t Eq128(D d, const svuint64_t a, const svuint64_t b) { |
| #if HWY_TARGET == HWY_SVE_256 |
| return MaskFromVec(detail::Eq128Vec(d, a, b)); |
| #else |
| static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64"); |
| const svbool_t eqHL = Eq(a, b); |
| const svbool_t eqHH = detail::DupOddB(d, eqHL); |
| const svbool_t eqLL = detail::DupEvenB(d, eqHL); |
| return And(eqLL, eqHH); |
| #endif // HWY_TARGET != HWY_SVE_256 |
| } |
| |
| template <class D> |
| HWY_INLINE svbool_t Ne128(D d, const svuint64_t a, const svuint64_t b) { |
| #if HWY_TARGET == HWY_SVE_256 |
| return MaskFromVec(detail::Ne128Vec(d, a, b)); |
| #else |
| static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64"); |
| const svbool_t neHL = Ne(a, b); |
| const svbool_t neHH = detail::DupOddB(d, neHL); |
| const svbool_t neLL = detail::DupEvenB(d, neHL); |
| return Or(neLL, neHH); |
| #endif // HWY_TARGET != HWY_SVE_256 |
| } |
| |
| // ------------------------------ Eq128Upper, Ne128Upper |
| |
| template <class D> |
| HWY_INLINE svbool_t Eq128Upper(D d, svuint64_t a, svuint64_t b) { |
| static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64"); |
| const svbool_t eqHL = Eq(a, b); |
| return detail::DupOddB(d, eqHL); |
| } |
| |
| template <class D> |
| HWY_INLINE svbool_t Ne128Upper(D d, svuint64_t a, svuint64_t b) { |
| static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64"); |
| const svbool_t neHL = Ne(a, b); |
| return detail::DupOddB(d, neHL); |
| } |
| |
| // ------------------------------ Min128, Max128 (Lt128) |
| |
| template <class D> |
| HWY_INLINE svuint64_t Min128(D d, const svuint64_t a, const svuint64_t b) { |
| #if HWY_TARGET == HWY_SVE_256 |
| return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); |
| #else |
| return IfThenElse(Lt128(d, a, b), a, b); |
| #endif |
| } |
| |
| template <class D> |
| HWY_INLINE svuint64_t Max128(D d, const svuint64_t a, const svuint64_t b) { |
| #if HWY_TARGET == HWY_SVE_256 |
| return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); |
| #else |
| return IfThenElse(Lt128(d, b, a), a, b); |
| #endif |
| } |
| |
| template <class D> |
| HWY_INLINE svuint64_t Min128Upper(D d, const svuint64_t a, const svuint64_t b) { |
| return IfThenElse(Lt128Upper(d, a, b), a, b); |
| } |
| |
| template <class D> |
| HWY_INLINE svuint64_t Max128Upper(D d, const svuint64_t a, const svuint64_t b) { |
| return IfThenElse(Lt128Upper(d, b, a), a, b); |
| } |
| |
| // -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex |
| |
| #ifdef HWY_NATIVE_LEADING_ZERO_COUNT |
| #undef HWY_NATIVE_LEADING_ZERO_COUNT |
| #else |
| #define HWY_NATIVE_LEADING_ZERO_COUNT |
| #endif |
| |
| #define HWY_SVE_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ |
| const DFromV<decltype(v)> d; \ |
| return BitCast(d, sv##OP##_##CHAR##BITS##_x(detail::PTrue(d), v)); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_LEADING_ZERO_COUNT, LeadingZeroCount, clz) |
| #undef HWY_SVE_LEADING_ZERO_COUNT |
| |
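// Reversing the bits moves trailing zeros to the front, so the leading zero
// count of the reversed value equals the trailing zero count.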
| template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> |
| HWY_API V TrailingZeroCount(V v) { |
| return LeadingZeroCount(ReverseBits(v)); |
| } |
| |
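// Index of the highest set bit, e.g. 4 for v = 0x10: (bit width - 1) - CLZ.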
| template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> |
| HWY_API V HighestSetBitIndex(V v) { |
| const DFromV<decltype(v)> d; |
| using T = TFromD<decltype(d)>; |
| return BitCast(d, Sub(Set(d, T{sizeof(T) * 8 - 1}), LeadingZeroCount(v))); |
| } |
| |
| #ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT |
| #undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT |
| #else |
| #define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT |
| #endif |
| |
| #define HWY_SVE_MASKED_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \ |
| HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ |
| const DFromV<decltype(v)> d; \ |
| return BitCast(d, sv##OP##_##CHAR##BITS##_z(m, v)); \ |
| } |
| |
| HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT, MaskedLeadingZeroCount, |
| clz) |
#undef HWY_SVE_MASKED_LEADING_ZERO_COUNT
| |
| // ================================================== END MACROS |
| #undef HWY_SVE_ALL_PTRUE |
| #undef HWY_SVE_D |
| #undef HWY_SVE_FOREACH |
| #undef HWY_SVE_FOREACH_BF16 |
| #undef HWY_SVE_FOREACH_BF16_UNCONDITIONAL |
| #undef HWY_SVE_FOREACH_F |
| #undef HWY_SVE_FOREACH_F16 |
| #undef HWY_SVE_FOREACH_F32 |
| #undef HWY_SVE_FOREACH_F3264 |
| #undef HWY_SVE_FOREACH_F64 |
| #undef HWY_SVE_FOREACH_I |
| #undef HWY_SVE_FOREACH_I08 |
| #undef HWY_SVE_FOREACH_I16 |
| #undef HWY_SVE_FOREACH_I32 |
| #undef HWY_SVE_FOREACH_I64 |
| #undef HWY_SVE_FOREACH_IF |
| #undef HWY_SVE_FOREACH_U |
| #undef HWY_SVE_FOREACH_U08 |
| #undef HWY_SVE_FOREACH_U16 |
| #undef HWY_SVE_FOREACH_U32 |
| #undef HWY_SVE_FOREACH_U64 |
| #undef HWY_SVE_FOREACH_UI |
| #undef HWY_SVE_FOREACH_UI08 |
| #undef HWY_SVE_FOREACH_UI16 |
| #undef HWY_SVE_FOREACH_UI32 |
| #undef HWY_SVE_FOREACH_UI64 |
| #undef HWY_SVE_FOREACH_UIF3264 |
| #undef HWY_SVE_HAVE_2 |
| #undef HWY_SVE_IF_EMULATED_D |
| #undef HWY_SVE_IF_NOT_EMULATED_D |
| #undef HWY_SVE_PTRUE |
| #undef HWY_SVE_RETV_ARGMVV |
| #undef HWY_SVE_RETV_ARGMVV_Z |
| #undef HWY_SVE_RETV_ARGMV_Z |
| #undef HWY_SVE_RETV_ARGMV |
| #undef HWY_SVE_RETV_ARGPV |
| #undef HWY_SVE_RETV_ARGPVN |
| #undef HWY_SVE_RETV_ARGPVV |
| #undef HWY_SVE_RETV_ARGV |
| #undef HWY_SVE_RETV_ARGVN |
| #undef HWY_SVE_RETV_ARGMV_M |
| #undef HWY_SVE_RETV_ARGVV |
| #undef HWY_SVE_RETV_ARGVVV |
| #undef HWY_SVE_RETV_ARGMVVV_Z |
| #undef HWY_SVE_RETV_ARGMVVV |
| #undef HWY_SVE_T |
| #undef HWY_SVE_UNDEFINED |
| #undef HWY_SVE_V |
| |
| // NOLINTNEXTLINE(google-readability-namespace-comments) |
| } // namespace HWY_NAMESPACE |
| } // namespace hwy |
| HWY_AFTER_NAMESPACE(); |