| // Copyright 2022 Google LLC |
| // SPDX-License-Identifier: Apache-2.0 |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // Per-target include guard |
| #if defined(HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_) == \ |
| defined(HWY_TARGET_TOGGLE) // NOLINT |
| #ifdef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ |
| #undef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ |
| #else |
| #define HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ |
| #endif |
| |
| #include "third_party/highway/hwy/highway.h" |
| |
| HWY_BEFORE_NAMESPACE(); |
| namespace hwy { |
| namespace HWY_NAMESPACE { |
| |
| // Returns index of the first element equal to `value` in `in[0, count)`, or |
| // `count` if not found. |
| template <class D, typename T = TFromD<D>> |
| size_t Find(D d, T value, const T* HWY_RESTRICT in, size_t count) { |
| const size_t N = Lanes(d); |
| const Vec<D> broadcasted = Set(d, value); |
| |
| size_t i = 0; |
| if (count >= N) { |
| for (; i <= count - N; i += N) { |
| const intptr_t pos = FindFirstTrue(d, Eq(broadcasted, LoadU(d, in + i))); |
| if (pos >= 0) return i + static_cast<size_t>(pos); |
| } |
| } |
| |
| if (i != count) { |
| #if HWY_MEM_OPS_MIGHT_FAULT |
| // Scan single elements. |
| const CappedTag<T, 1> d1; |
| using V1 = Vec<decltype(d1)>; |
| const V1 broadcasted1 = Set(d1, GetLane(broadcasted)); |
| for (; i < count; ++i) { |
| if (AllTrue(d1, Eq(broadcasted1, LoadU(d1, in + i)))) { |
| return i; |
| } |
| } |
| #else |
| const size_t remaining = count - i; |
| HWY_DASSERT(0 != remaining && remaining < N); |
| const Mask<D> mask = FirstN(d, remaining); |
| const Vec<D> v = MaskedLoad(mask, d, in + i); |
| // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. |
| const intptr_t pos = FindFirstTrue(d, And(Eq(broadcasted, v), mask)); |
| if (pos >= 0) return i + static_cast<size_t>(pos); |
| #endif // HWY_MEM_OPS_MIGHT_FAULT |
| } |
| |
| return count; // not found |
| } |
| |
| // Returns index of the first element in `in[0, count)` for which `func(d, vec)` |
| // returns true, otherwise `count`. |
| template <class D, class Func, typename T = TFromD<D>> |
| size_t FindIf(D d, const T* HWY_RESTRICT in, size_t count, const Func& func) { |
| const size_t N = Lanes(d); |
| |
| size_t i = 0; |
| if (count >= N) { |
| for (; i <= count - N; i += N) { |
| const intptr_t pos = FindFirstTrue(d, func(d, LoadU(d, in + i))); |
| if (pos >= 0) return i + static_cast<size_t>(pos); |
| } |
| } |
| |
| if (i != count) { |
| #if HWY_MEM_OPS_MIGHT_FAULT |
| // Scan single elements. |
| const CappedTag<T, 1> d1; |
| for (; i < count; ++i) { |
| if (AllTrue(d1, func(d1, LoadU(d1, in + i)))) { |
| return i; |
| } |
| } |
| #else |
| const size_t remaining = count - i; |
| HWY_DASSERT(0 != remaining && remaining < N); |
| const Mask<D> mask = FirstN(d, remaining); |
| const Vec<D> v = MaskedLoad(mask, d, in + i); |
| // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. |
| const intptr_t pos = FindFirstTrue(d, And(func(d, v), mask)); |
| if (pos >= 0) return i + static_cast<size_t>(pos); |
| #endif // HWY_MEM_OPS_MIGHT_FAULT |
| } |
| |
| return count; // not found |
| } |
| |
| // NOLINTNEXTLINE(google-readability-namespace-comments) |
| } // namespace HWY_NAMESPACE |
| } // namespace hwy |
| HWY_AFTER_NAMESPACE(); |
| |
| #endif // HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ |