warp_plane_hwy: Fix AVX512 edge filter loading

Previously, the full 4 blocks would be loaded from the blocks = 2 case,
leading to an out-of-bounds read.

Instead, propagate the tag so only 2 blocks are loaded in this case.

Change-Id: I59526b411638e864659eccce007af12abfa46a06
Bug: 447385711
diff --git a/av1/common/warp_plane_hwy.h b/av1/common/warp_plane_hwy.h
index 512aeef..61d5fc6 100644
--- a/av1/common/warp_plane_hwy.h
+++ b/av1/common/warp_plane_hwy.h
@@ -23,13 +23,12 @@
 
 namespace hn = hwy::HWY_NAMESPACE;
 
-constexpr hn::ScalableTag<uint8_t> uint8_tag;
-constexpr hn::ScalableTag<uint16_t> uint16_tag;
+constexpr hn::ScalableTag<uint8_t> uint8xN_tag;
+constexpr hn::ScalableTag<uint16_t> uint16xN_tag;
 
-constexpr hn::ScalableTag<int8_t> int8_tag;
-constexpr hn::ScalableTag<int16_t> int16_tag;
-constexpr hn::ScalableTag<int32_t> int32_tag;
-constexpr hn::ScalableTag<int64_t> int64_tag;
+constexpr hn::ScalableTag<int16_t> int16xN_tag;
+constexpr hn::ScalableTag<int32_t> int32xN_tag;
+constexpr hn::ScalableTag<int64_t> int64xN_tag;
 
 constexpr hn::CappedTag<uint8_t, 32> uint8x32_tag;
 constexpr hn::CappedTag<int16_t, 16> int16x16_tag;
@@ -46,9 +45,10 @@
 constexpr hn::FixedTag<int32_t, 4> int32x4_tag;
 constexpr hn::FixedTag<int64_t, 2> int64x2_tag;
 
-using IVec8 = hn::Vec<decltype(int8_tag)>;
-using IVec16 = hn::Vec<decltype(int16_tag)>;
-using IVec32 = hn::Vec<decltype(int32_tag)>;
+constexpr hn::ScalableTag<int8_t> coeff_tag;
+
+using IVec16 = hn::Vec<decltype(int16xN_tag)>;
+using IVec32 = hn::Vec<decltype(int32xN_tag)>;
 using IVec8x16 = hn::Vec<decltype(int8x16_tag)>;
 
 template <typename D>
@@ -57,27 +57,27 @@
                                             int8_t *HWY_RESTRICT coeff,
                                             const IVec16 round_const,
                                             const int shift, int row) {
-  constexpr hn::Repartition<int8_t, D> coeff_tag;
+  constexpr hn::Repartition<int8_t, D> int8_tag;
   constexpr hn::Repartition<int16_t, D> result_tag;
   constexpr hn::Repartition<uint16_t, D> unsigned_result_tag;
   // N.B. coeffs are stored to support the maximum vector width, which may not
   // be the vector width being filtered on now.
-  const auto coeff0 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 0);
-  const auto coeff1 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 1);
-  const auto coeff2 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 2);
-  const auto coeff3 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 3);
+  const auto coeff0 = hn::Load(int8_tag, coeff + hn::MaxLanes(coeff_tag) * 0);
+  const auto coeff1 = hn::Load(int8_tag, coeff + hn::MaxLanes(coeff_tag) * 1);
+  const auto coeff2 = hn::Load(int8_tag, coeff + hn::MaxLanes(coeff_tag) * 2);
+  const auto coeff3 = hn::Load(int8_tag, coeff + hn::MaxLanes(coeff_tag) * 3);
 
   const auto shuffle0 = hn::Dup128VecFromValues(
-      uint8_tag, 0, 2, 2, 4, 4, 6, 6, 8, 1, 3, 3, 5, 5, 7, 7, 9  //
+      uint8xN_tag, 0, 2, 2, 4, 4, 6, 6, 8, 1, 3, 3, 5, 5, 7, 7, 9  //
   );
   const auto shuffle1 = hn::Dup128VecFromValues(
-      uint8_tag, 4, 6, 6, 8, 8, 10, 10, 12, 5, 7, 7, 9, 9, 11, 11, 13  //
+      uint8xN_tag, 4, 6, 6, 8, 8, 10, 10, 12, 5, 7, 7, 9, 9, 11, 11, 13  //
   );
   const auto shuffle2 = hn::Dup128VecFromValues(
-      uint8_tag, 1, 3, 3, 5, 5, 7, 7, 9, 2, 4, 4, 6, 6, 8, 8, 10  //
+      uint8xN_tag, 1, 3, 3, 5, 5, 7, 7, 9, 2, 4, 4, 6, 6, 8, 8, 10  //
   );
   const auto shuffle3 = hn::Dup128VecFromValues(
-      uint8_tag, 5, 7, 7, 9, 9, 11, 11, 13, 6, 8, 8, 10, 10, 12, 12, 14  //
+      uint8xN_tag, 5, 7, 7, 9, 9, 11, 11, 13, 6, 8, 8, 10, 10, 12, 12, 14  //
   );
 
   const auto src_0 =
@@ -111,59 +111,84 @@
                    8);
 }
 
-HWY_ATTR HWY_INLINE IVec8 LoadAV1Filter8BitLower(unsigned int offset) {
+template <typename D>
+HWY_ATTR HWY_INLINE hn::VFromD<D> LoadAV1Filter8BitLower(D int8_tag,
+                                                         unsigned int offset) {
   return hn::LoadN(int8_tag, av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS],
                    8);
 }
 
-template <int Block>
-HWY_ATTR HWY_INLINE IVec8 LoadAV1Filter8BitUpper(unsigned int offset,
-                                                 IVec8 src) {
+template <int Block, typename D>
+HWY_ATTR HWY_INLINE hn::VFromD<D> LoadAV1Filter8BitUpper(D int8_tag,
+                                                         unsigned int offset,
+                                                         hn::VFromD<D> src) {
+  (void)int8_tag;
   return hn::InsertBlock<Block>(
       src, hn::LoadN(int8x16_tag,
                      av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS], 8));
 }
 
+template <typename D>
 HWY_ATTR inline void PrepareHorizontalFilterCoefficients(
-    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
-  auto tmp_0 = LoadAV1Filter8BitLower(sx + 0 * alpha);
-  auto tmp_1 = LoadAV1Filter8BitLower(sx + 1 * alpha);
-  auto tmp_2 = LoadAV1Filter8BitLower(sx + 2 * alpha);
-  auto tmp_3 = LoadAV1Filter8BitLower(sx + 3 * alpha);
-  auto tmp_4 = LoadAV1Filter8BitLower(sx + 4 * alpha);
-  auto tmp_5 = LoadAV1Filter8BitLower(sx + 5 * alpha);
-  auto tmp_6 = LoadAV1Filter8BitLower(sx + 6 * alpha);
-  auto tmp_7 = LoadAV1Filter8BitLower(sx + 7 * alpha);
+    D int16_tag, int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
+  constexpr auto int8_tag = hn::Repartition<int8_t, D>();
+  constexpr auto int32_tag = hn::Repartition<int32_t, D>();
+  constexpr auto int64_tag = hn::Repartition<int64_t, D>();
+
+  auto tmp_0 = LoadAV1Filter8BitLower(int8_tag, sx + 0 * alpha);
+  auto tmp_1 = LoadAV1Filter8BitLower(int8_tag, sx + 1 * alpha);
+  auto tmp_2 = LoadAV1Filter8BitLower(int8_tag, sx + 2 * alpha);
+  auto tmp_3 = LoadAV1Filter8BitLower(int8_tag, sx + 3 * alpha);
+  auto tmp_4 = LoadAV1Filter8BitLower(int8_tag, sx + 4 * alpha);
+  auto tmp_5 = LoadAV1Filter8BitLower(int8_tag, sx + 5 * alpha);
+  auto tmp_6 = LoadAV1Filter8BitLower(int8_tag, sx + 6 * alpha);
+  auto tmp_7 = LoadAV1Filter8BitLower(int8_tag, sx + 7 * alpha);
 
   if constexpr (int16_tag.MaxBlocks() >= 2) {
-    tmp_0 = LoadAV1Filter8BitUpper<1>(sx + beta + 0 * alpha, tmp_0);
-    tmp_1 = LoadAV1Filter8BitUpper<1>(sx + beta + 1 * alpha, tmp_1);
-    tmp_2 = LoadAV1Filter8BitUpper<1>(sx + beta + 2 * alpha, tmp_2);
-    tmp_3 = LoadAV1Filter8BitUpper<1>(sx + beta + 3 * alpha, tmp_3);
-    tmp_4 = LoadAV1Filter8BitUpper<1>(sx + beta + 4 * alpha, tmp_4);
-    tmp_5 = LoadAV1Filter8BitUpper<1>(sx + beta + 5 * alpha, tmp_5);
-    tmp_6 = LoadAV1Filter8BitUpper<1>(sx + beta + 6 * alpha, tmp_6);
-    tmp_7 = LoadAV1Filter8BitUpper<1>(sx + beta + 7 * alpha, tmp_7);
+    tmp_0 = LoadAV1Filter8BitUpper<1>(int8_tag, sx + beta + 0 * alpha, tmp_0);
+    tmp_1 = LoadAV1Filter8BitUpper<1>(int8_tag, sx + beta + 1 * alpha, tmp_1);
+    tmp_2 = LoadAV1Filter8BitUpper<1>(int8_tag, sx + beta + 2 * alpha, tmp_2);
+    tmp_3 = LoadAV1Filter8BitUpper<1>(int8_tag, sx + beta + 3 * alpha, tmp_3);
+    tmp_4 = LoadAV1Filter8BitUpper<1>(int8_tag, sx + beta + 4 * alpha, tmp_4);
+    tmp_5 = LoadAV1Filter8BitUpper<1>(int8_tag, sx + beta + 5 * alpha, tmp_5);
+    tmp_6 = LoadAV1Filter8BitUpper<1>(int8_tag, sx + beta + 6 * alpha, tmp_6);
+    tmp_7 = LoadAV1Filter8BitUpper<1>(int8_tag, sx + beta + 7 * alpha, tmp_7);
   }
 
   if constexpr (int16_tag.MaxBlocks() >= 3) {
-    tmp_0 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 0 * alpha, tmp_0);
-    tmp_1 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 1 * alpha, tmp_1);
-    tmp_2 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 2 * alpha, tmp_2);
-    tmp_3 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 3 * alpha, tmp_3);
-    tmp_4 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 4 * alpha, tmp_4);
-    tmp_5 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 5 * alpha, tmp_5);
-    tmp_6 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 6 * alpha, tmp_6);
-    tmp_7 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 7 * alpha, tmp_7);
+    tmp_0 =
+        LoadAV1Filter8BitUpper<2>(int8_tag, sx + beta * 2 + 0 * alpha, tmp_0);
+    tmp_1 =
+        LoadAV1Filter8BitUpper<2>(int8_tag, sx + beta * 2 + 1 * alpha, tmp_1);
+    tmp_2 =
+        LoadAV1Filter8BitUpper<2>(int8_tag, sx + beta * 2 + 2 * alpha, tmp_2);
+    tmp_3 =
+        LoadAV1Filter8BitUpper<2>(int8_tag, sx + beta * 2 + 3 * alpha, tmp_3);
+    tmp_4 =
+        LoadAV1Filter8BitUpper<2>(int8_tag, sx + beta * 2 + 4 * alpha, tmp_4);
+    tmp_5 =
+        LoadAV1Filter8BitUpper<2>(int8_tag, sx + beta * 2 + 5 * alpha, tmp_5);
+    tmp_6 =
+        LoadAV1Filter8BitUpper<2>(int8_tag, sx + beta * 2 + 6 * alpha, tmp_6);
+    tmp_7 =
+        LoadAV1Filter8BitUpper<2>(int8_tag, sx + beta * 2 + 7 * alpha, tmp_7);
 
-    tmp_0 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 0 * alpha, tmp_0);
-    tmp_1 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 1 * alpha, tmp_1);
-    tmp_2 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 2 * alpha, tmp_2);
-    tmp_3 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 3 * alpha, tmp_3);
-    tmp_4 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 4 * alpha, tmp_4);
-    tmp_5 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 5 * alpha, tmp_5);
-    tmp_6 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 6 * alpha, tmp_6);
-    tmp_7 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 7 * alpha, tmp_7);
+    tmp_0 =
+        LoadAV1Filter8BitUpper<3>(int8_tag, sx + beta * 3 + 0 * alpha, tmp_0);
+    tmp_1 =
+        LoadAV1Filter8BitUpper<3>(int8_tag, sx + beta * 3 + 1 * alpha, tmp_1);
+    tmp_2 =
+        LoadAV1Filter8BitUpper<3>(int8_tag, sx + beta * 3 + 2 * alpha, tmp_2);
+    tmp_3 =
+        LoadAV1Filter8BitUpper<3>(int8_tag, sx + beta * 3 + 3 * alpha, tmp_3);
+    tmp_4 =
+        LoadAV1Filter8BitUpper<3>(int8_tag, sx + beta * 3 + 4 * alpha, tmp_4);
+    tmp_5 =
+        LoadAV1Filter8BitUpper<3>(int8_tag, sx + beta * 3 + 5 * alpha, tmp_5);
+    tmp_6 =
+        LoadAV1Filter8BitUpper<3>(int8_tag, sx + beta * 3 + 6 * alpha, tmp_6);
+    tmp_7 =
+        LoadAV1Filter8BitUpper<3>(int8_tag, sx + beta * 3 + 7 * alpha, tmp_7);
   }
 
   const auto tmp_0_16 = hn::BitCast(int16_tag, tmp_0);
@@ -186,17 +211,19 @@
   const auto res_3 = hn::ZipUpper(int64_tag, tmp_13, tmp_15);
 
   hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_0, res_2)),
-            int8_tag, coeff + hn::MaxLanes(int8_tag) * 0);
+            int8_tag, coeff + hn::MaxLanes(coeff_tag) * 0);
   hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_0, res_2)),
-            int8_tag, coeff + hn::MaxLanes(int8_tag) * 1);
+            int8_tag, coeff + hn::MaxLanes(coeff_tag) * 1);
   hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_1, res_3)),
-            int8_tag, coeff + hn::MaxLanes(int8_tag) * 2);
+            int8_tag, coeff + hn::MaxLanes(coeff_tag) * 2);
   hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_1, res_3)),
-            int8_tag, coeff + hn::MaxLanes(int8_tag) * 3);
+            int8_tag, coeff + hn::MaxLanes(coeff_tag) * 3);
 }
 
+template <typename D>
 HWY_ATTR inline void PrepareHorizontalFilterCoefficientsBeta0(
-    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
+    D int16_tag, int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
+  (void)int16_tag;
   (void)beta;
   const auto tmp_0 =
       hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 0 * alpha));
@@ -220,6 +247,10 @@
   const auto tmp_46 = hn::ZipLower(int32x4_tag, tmp_4, tmp_6);
   const auto tmp_57 = hn::ZipLower(int32x4_tag, tmp_5, tmp_7);
 
+  constexpr auto int8_tag = hn::Repartition<int8_t, D>();
+  constexpr auto int32_tag = hn::Repartition<int32_t, D>();
+  constexpr auto int64_tag = hn::Repartition<int64_t, D>();
+
   const auto broadcast_12 =
       hn::BroadcastBlock<0>(hn::ResizeBitCast(int32_tag, tmp_02));
   const auto broadcast_13 =
@@ -235,36 +266,38 @@
   const auto res_3 = hn::ZipUpper(int64_tag, broadcast_13, broadcast_15);
 
   hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_0, res_2)),
-            int8_tag, coeff + hn::MaxLanes(int8_tag) * 0);
+            int8_tag, coeff + hn::MaxLanes(coeff_tag) * 0);
   hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_0, res_2)),
-            int8_tag, coeff + hn::MaxLanes(int8_tag) * 1);
+            int8_tag, coeff + hn::MaxLanes(coeff_tag) * 1);
   hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_1, res_3)),
-            int8_tag, coeff + hn::MaxLanes(int8_tag) * 2);
+            int8_tag, coeff + hn::MaxLanes(coeff_tag) * 2);
   hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_1, res_3)),
-            int8_tag, coeff + hn::MaxLanes(int8_tag) * 3);
+            int8_tag, coeff + hn::MaxLanes(coeff_tag) * 3);
 }
 
+template <typename D>
 HWY_ATTR inline void PrepareHorizontalFilterCoefficientsAlpha0(
-    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
+    D int16_tag, int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
   (void)alpha;
-  auto tmp_0 = LoadAV1Filter8BitLower(sx);
+  constexpr auto int8_tag = hn::Repartition<int8_t, D>();
+  auto tmp_0 = LoadAV1Filter8BitLower(int8_tag, sx);
   if constexpr (int16_tag.MaxBlocks() >= 2) {
-    tmp_0 = LoadAV1Filter8BitUpper<1>(sx + beta, tmp_0);
+    tmp_0 = LoadAV1Filter8BitUpper<1>(int8_tag, sx + beta, tmp_0);
   }
   if constexpr (int16_tag.MaxBlocks() >= 3) {
-    tmp_0 = LoadAV1Filter8BitUpper<2>(sx + beta * 2, tmp_0);
-    tmp_0 = LoadAV1Filter8BitUpper<3>(sx + beta * 3, tmp_0);
+    tmp_0 = LoadAV1Filter8BitUpper<2>(int8_tag, sx + beta * 2, tmp_0);
+    tmp_0 = LoadAV1Filter8BitUpper<3>(int8_tag, sx + beta * 3, tmp_0);
   }
   const auto res_0 = hn::BitCast(int16_tag, tmp_0);
 
   hn::Store(hn::BitCast(int8_tag, hn::Broadcast<0>(res_0)), int8_tag,
-            coeff + hn::MaxLanes(int8_tag) * 0);
+            coeff + hn::MaxLanes(coeff_tag) * 0);
   hn::Store(hn::BitCast(int8_tag, hn::Broadcast<1>(res_0)), int8_tag,
-            coeff + hn::MaxLanes(int8_tag) * 1);
+            coeff + hn::MaxLanes(coeff_tag) * 1);
   hn::Store(hn::BitCast(int8_tag, hn::Broadcast<2>(res_0)), int8_tag,
-            coeff + hn::MaxLanes(int8_tag) * 2);
+            coeff + hn::MaxLanes(coeff_tag) * 2);
   hn::Store(hn::BitCast(int8_tag, hn::Broadcast<3>(res_0)), int8_tag,
-            coeff + hn::MaxLanes(int8_tag) * 3);
+            coeff + hn::MaxLanes(coeff_tag) * 3);
 }
 
 template <typename D>
@@ -273,14 +306,17 @@
                                       int alpha, int beta, int row,
                                       const IVec16 round_const,
                                       const int reduce_bits_horiz) {
-  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(int8_tag)];
-  PrepareHorizontalFilterCoefficients(alpha, beta, sx, coeff);
+  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(coeff_tag)];
+  PrepareHorizontalFilterCoefficients(hn::Repartition<int16_t, D>(), alpha,
+                                      beta, sx, coeff);
   FilterPixelsHorizontal(tag, src, horz_out, coeff, round_const,
                          reduce_bits_horiz, row);
 }
 
+template <typename D>
 HWY_ATTR inline void PrepareLastHorizontalFilterCoefficients(
-    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
+    D int16_tag, int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
+  (void)int16_tag;
   (void)beta;
   const auto tmp_0 =
       hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 0 * alpha));
@@ -314,19 +350,21 @@
   const auto tmp_18 = hn::InterleaveLower(int64x2_tag, tmp_13, tmp_15);
   const auto tmp_19 = hn::InterleaveUpper(int64x2_tag, tmp_13, tmp_15);
 
+  constexpr auto int8_tag = hn::Repartition<int8_t, D>();
+
   const auto tmp_20 = hn::ResizeBitCast(int8_tag, tmp_16);
   const auto tmp_21 = hn::ResizeBitCast(int8_tag, tmp_17);
   const auto tmp_22 = hn::ResizeBitCast(int8_tag, tmp_18);
   const auto tmp_23 = hn::ResizeBitCast(int8_tag, tmp_19);
 
   hn::Store(hn::BroadcastBlock<0>(tmp_20), int8_tag,
-            coeff + hn::MaxLanes(int8_tag) * 0);
+            coeff + hn::MaxLanes(coeff_tag) * 0);
   hn::Store(hn::BroadcastBlock<0>(tmp_21), int8_tag,
-            coeff + hn::MaxLanes(int8_tag) * 1);
+            coeff + hn::MaxLanes(coeff_tag) * 1);
   hn::Store(hn::BroadcastBlock<0>(tmp_22), int8_tag,
-            coeff + hn::MaxLanes(int8_tag) * 2);
+            coeff + hn::MaxLanes(coeff_tag) * 2);
   hn::Store(hn::BroadcastBlock<0>(tmp_23), int8_tag,
-            coeff + hn::MaxLanes(int8_tag) * 3);
+            coeff + hn::MaxLanes(coeff_tag) * 3);
 }
 
 template <typename D>
@@ -374,43 +412,79 @@
   return k;
 }
 
-template <
-    bool InnerCoeffUpdate,
-    void (*PrepareCoeffs)(int alpha, int beta, int sx,
-                          int8_t *HWY_RESTRICT coeffs),
-    void (*LastPrepareCoeffs)(int alpha, int beta, int sx,
-                              int8_t *HWY_RESTRICT coeffs) = PrepareCoeffs>
+enum class HorizontalFilterCoeffs {
+  kAlpha0,
+  kBeta0,
+  kDefault,
+};
+
+template <bool IsLast, HorizontalFilterCoeffs Filter, typename D>
+HWY_ATTR void WarpHorizontalPrepareCoeffs(int alpha, int beta, int sx,
+                                          int8_t *HWY_RESTRICT coeffs) {
+  D int16_tag;
+  switch (Filter) {
+    case HorizontalFilterCoeffs::kAlpha0:
+      PrepareHorizontalFilterCoefficientsAlpha0(int16_tag, alpha, beta, sx,
+                                                coeffs);
+      return;
+    case HorizontalFilterCoeffs::kBeta0:
+      PrepareHorizontalFilterCoefficientsBeta0(int16_tag, alpha, beta, sx,
+                                               coeffs);
+      return;
+    case HorizontalFilterCoeffs::kDefault:
+    default:
+      if (IsLast) {
+        PrepareLastHorizontalFilterCoefficients(int16_tag, alpha, beta, sx,
+                                                coeffs);
+      } else {
+        PrepareHorizontalFilterCoefficients(int16_tag, alpha, beta, sx, coeffs);
+      }
+      return;
+  }
+}
+
+template <bool InnerCoeffUpdate, HorizontalFilterCoeffs Filter>
 HWY_ATTR inline void WarpHorizontalFilterTemplate(
     const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out, int stride,
     int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta, int p_height,
     int height, int i, const IVec16 round_const, const int reduce_bits_horiz) {
   int k = -7, iy;
-  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(int8_tag)];
+  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(coeff_tag)];
   if constexpr (!InnerCoeffUpdate) {
-    PrepareCoeffs(alpha, beta, sx4, coeff);
+    WarpHorizontalPrepareCoeffs<false, Filter, decltype(int16xN_tag)>(
+        alpha, beta, sx4, coeff);
   }
-  if constexpr (uint8_tag.MaxBlocks() >= 3) {
-    k = WarpHorizontalFilterLoop<(InnerCoeffUpdate ? PrepareCoeffs : nullptr)>(
-        uint8_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height,
-        height, i, round_const, reduce_bits_horiz, k, coeff);
+  if constexpr (uint8xN_tag.MaxBlocks() >= 3) {
+    k = WarpHorizontalFilterLoop<(
+        InnerCoeffUpdate
+            ? WarpHorizontalPrepareCoeffs<false, Filter, decltype(int16xN_tag)>
+            : nullptr)>(uint8xN_tag, ref, horz_out, stride, ix4, iy4, sx4,
+                        alpha, beta, p_height, height, i, round_const,
+                        reduce_bits_horiz, k, coeff);
   }
-  if constexpr (uint8_tag.MaxBlocks() >= 2) {
-    k = WarpHorizontalFilterLoop<(InnerCoeffUpdate ? PrepareCoeffs : nullptr)>(
-        uint8x32_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha, beta,
-        p_height, height, i, round_const, reduce_bits_horiz, k, coeff);
+  if constexpr (uint8xN_tag.MaxBlocks() >= 2) {
+    k = WarpHorizontalFilterLoop<(
+        InnerCoeffUpdate
+            ? WarpHorizontalPrepareCoeffs<false, Filter, decltype(int16x16_tag)>
+            : nullptr)>(uint8x32_tag, ref, horz_out, stride, ix4, iy4, sx4,
+                        alpha, beta, p_height, height, i, round_const,
+                        reduce_bits_horiz, k, coeff);
   }
-  if constexpr (uint8_tag.MaxBlocks() == 1) {
-    k = WarpHorizontalFilterLoop<(InnerCoeffUpdate ? LastPrepareCoeffs
-                                                   : nullptr)>(
-        uint8x16_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha, beta,
-        p_height, height, i, round_const, reduce_bits_horiz, k, coeff);
+  if constexpr (uint8xN_tag.MaxBlocks() == 1) {
+    k = WarpHorizontalFilterLoop<(
+        InnerCoeffUpdate
+            ? WarpHorizontalPrepareCoeffs<true, Filter, decltype(int16x8_tag)>
+            : nullptr)>(uint8x16_tag, ref, horz_out, stride, ix4, iy4, sx4,
+                        alpha, beta, p_height, height, i, round_const,
+                        reduce_bits_horiz, k, coeff);
   }
   iy = iy4 + k;
   iy = clamp(iy, 0, height - 1);
   const auto src = hn::LoadU(uint8x16_tag, ref + iy * stride + ix4 - 7);
   if constexpr (InnerCoeffUpdate) {
     int sx = sx4 + beta * (k + 4);
-    LastPrepareCoeffs(alpha, beta, sx, coeff);
+    WarpHorizontalPrepareCoeffs<true, Filter, decltype(int16x8_tag)>(
+        alpha, beta, sx, coeff);
   }
   FilterPixelsHorizontal(uint8x16_tag, src, horz_out, coeff, round_const,
                          reduce_bits_horiz, k + 7);
@@ -421,25 +495,25 @@
     const int offset_bits, IVec16 &HWY_RESTRICT res_sub_const,
     IVec16 &HWY_RESTRICT round_bits_const, IVec16 &HWY_RESTRICT wt) {
   res_sub_const =
-      hn::Set(int16_tag, -(1 << (offset_bits - conv_params->round_1)) -
-                             (1 << (offset_bits - conv_params->round_1 - 1)));
-  round_bits_const = hn::Set(int16_tag, ((1 << round_bits) >> 1));
+      hn::Set(int16xN_tag, -(1 << (offset_bits - conv_params->round_1)) -
+                               (1 << (offset_bits - conv_params->round_1 - 1)));
+  round_bits_const = hn::Set(int16xN_tag, ((1 << round_bits) >> 1));
 
   const auto w0 = static_cast<int16_t>(conv_params->fwd_offset);
   const auto w1 = static_cast<int16_t>(conv_params->bck_offset);
-  const auto wt0 = hn::Set(int16_tag, w0);
-  const auto wt1 = hn::Set(int16_tag, w1);
+  const auto wt0 = hn::Set(int16xN_tag, w0);
+  const auto wt1 = hn::Set(int16xN_tag, w1);
   wt = hn::InterleaveLower(wt0, wt1);
 }
 
 HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilter(size_t offset) {
-  return hn::LoadN(int16_tag, av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS],
-                   8);
+  return hn::LoadN(int16xN_tag,
+                   av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS], 8);
 }
 
 HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilterLower(size_t offset) {
   return hn::ResizeBitCast(
-      int16_tag,
+      int16xN_tag,
       hn::Load(int16x8_tag, av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS]));
 }
 
@@ -457,14 +531,14 @@
   auto filt_02 = LoadAV1WarpedFilterLower(sy + 4 * gamma);
   auto filt_03 = LoadAV1WarpedFilterLower(sy + 6 * gamma);
 
-  if constexpr (int16_tag.MaxBlocks() >= 2) {
+  if constexpr (int16xN_tag.MaxBlocks() >= 2) {
     filt_00 = LoadAV1WarpedFilterUpper<1>(sy + delta + 0 * gamma, filt_00);
     filt_01 = LoadAV1WarpedFilterUpper<1>(sy + delta + 2 * gamma, filt_01);
     filt_02 = LoadAV1WarpedFilterUpper<1>(sy + delta + 4 * gamma, filt_02);
     filt_03 = LoadAV1WarpedFilterUpper<1>(sy + delta + 6 * gamma, filt_03);
   }
 
-  if constexpr (int16_tag.MaxBlocks() >= 3) {
+  if constexpr (int16xN_tag.MaxBlocks() >= 3) {
     filt_00 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 0 * gamma, filt_00);
     filt_01 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 2 * gamma, filt_01);
     filt_02 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 4 * gamma, filt_02);
@@ -476,42 +550,42 @@
     filt_03 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 6 * gamma, filt_03);
   }
 
-  auto filt_0 = hn::BitCast(int32_tag, filt_00);
-  auto filt_1 = hn::BitCast(int32_tag, filt_01);
-  auto filt_2 = hn::BitCast(int32_tag, filt_02);
-  auto filt_3 = hn::BitCast(int32_tag, filt_03);
+  auto filt_0 = hn::BitCast(int32xN_tag, filt_00);
+  auto filt_1 = hn::BitCast(int32xN_tag, filt_01);
+  auto filt_2 = hn::BitCast(int32xN_tag, filt_02);
+  auto filt_3 = hn::BitCast(int32xN_tag, filt_03);
 
-  auto res_0 = hn::ZipLower(int64_tag, filt_0, filt_1);
-  auto res_1 = hn::ZipLower(int64_tag, filt_2, filt_3);
-  auto res_2 = hn::ZipUpper(int64_tag, filt_0, filt_1);
-  auto res_3 = hn::ZipUpper(int64_tag, filt_2, filt_3);
+  auto res_0 = hn::ZipLower(int64xN_tag, filt_0, filt_1);
+  auto res_1 = hn::ZipLower(int64xN_tag, filt_2, filt_3);
+  auto res_2 = hn::ZipUpper(int64xN_tag, filt_0, filt_1);
+  auto res_3 = hn::ZipUpper(int64xN_tag, filt_2, filt_3);
 
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
-      int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveLower(int64xN_tag, res_0, res_1)),
+      int16xN_tag, coeffs + 0 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
-      int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveUpper(int64xN_tag, res_0, res_1)),
+      int16xN_tag, coeffs + 1 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
-      int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveLower(int64xN_tag, res_2, res_3)),
+      int16xN_tag, coeffs + 2 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
-      int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveUpper(int64xN_tag, res_2, res_3)),
+      int16xN_tag, coeffs + 3 * hn::MaxLanes(int16xN_tag));
 
   filt_00 = LoadAV1WarpedFilterLower(sy + 1 * gamma);
   filt_01 = LoadAV1WarpedFilterLower(sy + 3 * gamma);
   filt_02 = LoadAV1WarpedFilterLower(sy + 5 * gamma);
   filt_03 = LoadAV1WarpedFilterLower(sy + 7 * gamma);
 
-  if constexpr (int16_tag.MaxBlocks() >= 2) {
+  if constexpr (int16xN_tag.MaxBlocks() >= 2) {
     filt_00 = LoadAV1WarpedFilterUpper<1>(sy + delta + 1 * gamma, filt_00);
     filt_01 = LoadAV1WarpedFilterUpper<1>(sy + delta + 3 * gamma, filt_01);
     filt_02 = LoadAV1WarpedFilterUpper<1>(sy + delta + 5 * gamma, filt_02);
     filt_03 = LoadAV1WarpedFilterUpper<1>(sy + delta + 7 * gamma, filt_03);
   }
 
-  if constexpr (int16_tag.MaxBlocks() >= 3) {
+  if constexpr (int16xN_tag.MaxBlocks() >= 3) {
     filt_00 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 1 * gamma, filt_00);
     filt_01 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 3 * gamma, filt_01);
     filt_02 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 5 * gamma, filt_02);
@@ -523,28 +597,28 @@
     filt_03 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 7 * gamma, filt_03);
   }
 
-  filt_0 = hn::BitCast(int32_tag, filt_00);
-  filt_1 = hn::BitCast(int32_tag, filt_01);
-  filt_2 = hn::BitCast(int32_tag, filt_02);
-  filt_3 = hn::BitCast(int32_tag, filt_03);
+  filt_0 = hn::BitCast(int32xN_tag, filt_00);
+  filt_1 = hn::BitCast(int32xN_tag, filt_01);
+  filt_2 = hn::BitCast(int32xN_tag, filt_02);
+  filt_3 = hn::BitCast(int32xN_tag, filt_03);
 
-  res_0 = hn::ZipLower(int64_tag, filt_0, filt_1);
-  res_1 = hn::ZipLower(int64_tag, filt_2, filt_3);
-  res_2 = hn::ZipUpper(int64_tag, filt_0, filt_1);
-  res_3 = hn::ZipUpper(int64_tag, filt_2, filt_3);
+  res_0 = hn::ZipLower(int64xN_tag, filt_0, filt_1);
+  res_1 = hn::ZipLower(int64xN_tag, filt_2, filt_3);
+  res_2 = hn::ZipUpper(int64xN_tag, filt_0, filt_1);
+  res_3 = hn::ZipUpper(int64xN_tag, filt_2, filt_3);
 
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
-      int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveLower(int64xN_tag, res_0, res_1)),
+      int16xN_tag, coeffs + 4 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
-      int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveUpper(int64xN_tag, res_0, res_1)),
+      int16xN_tag, coeffs + 5 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
-      int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveLower(int64xN_tag, res_2, res_3)),
+      int16xN_tag, coeffs + 6 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
-      int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveUpper(int64xN_tag, res_2, res_3)),
+      int16xN_tag, coeffs + 7 * hn::MaxLanes(int16xN_tag));
 }
 
 HWY_ATTR inline void PrepareVerticalFilterCoeffsDelta0(
@@ -555,125 +629,125 @@
   auto filt_02 = LoadAV1WarpedFilter(sy + 4 * gamma);
   auto filt_03 = LoadAV1WarpedFilter(sy + 6 * gamma);
 
-  auto filt_10 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_00));
-  auto filt_11 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_01));
-  auto filt_12 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_02));
-  auto filt_13 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_03));
+  auto filt_10 = hn::BitCast(int32xN_tag, hn::BroadcastBlock<0>(filt_00));
+  auto filt_11 = hn::BitCast(int32xN_tag, hn::BroadcastBlock<0>(filt_01));
+  auto filt_12 = hn::BitCast(int32xN_tag, hn::BroadcastBlock<0>(filt_02));
+  auto filt_13 = hn::BitCast(int32xN_tag, hn::BroadcastBlock<0>(filt_03));
 
-  auto res_0 = hn::ZipLower(int64_tag, filt_10, filt_11);
-  auto res_1 = hn::ZipLower(int64_tag, filt_12, filt_13);
-  auto res_2 = hn::ZipUpper(int64_tag, filt_10, filt_11);
-  auto res_3 = hn::ZipUpper(int64_tag, filt_12, filt_13);
+  auto res_0 = hn::ZipLower(int64xN_tag, filt_10, filt_11);
+  auto res_1 = hn::ZipLower(int64xN_tag, filt_12, filt_13);
+  auto res_2 = hn::ZipUpper(int64xN_tag, filt_10, filt_11);
+  auto res_3 = hn::ZipUpper(int64xN_tag, filt_12, filt_13);
 
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
-      int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveLower(int64xN_tag, res_0, res_1)),
+      int16xN_tag, coeffs + 0 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
-      int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveUpper(int64xN_tag, res_0, res_1)),
+      int16xN_tag, coeffs + 1 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
-      int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveLower(int64xN_tag, res_2, res_3)),
+      int16xN_tag, coeffs + 2 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
-      int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveUpper(int64xN_tag, res_2, res_3)),
+      int16xN_tag, coeffs + 3 * hn::MaxLanes(int16xN_tag));
 
   filt_00 = LoadAV1WarpedFilter(sy + 1 * gamma);
   filt_01 = LoadAV1WarpedFilter(sy + 3 * gamma);
   filt_02 = LoadAV1WarpedFilter(sy + 5 * gamma);
   filt_03 = LoadAV1WarpedFilter(sy + 7 * gamma);
 
-  filt_10 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_00));
-  filt_11 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_01));
-  filt_12 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_02));
-  filt_13 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_03));
+  filt_10 = hn::BitCast(int32xN_tag, hn::BroadcastBlock<0>(filt_00));
+  filt_11 = hn::BitCast(int32xN_tag, hn::BroadcastBlock<0>(filt_01));
+  filt_12 = hn::BitCast(int32xN_tag, hn::BroadcastBlock<0>(filt_02));
+  filt_13 = hn::BitCast(int32xN_tag, hn::BroadcastBlock<0>(filt_03));
 
-  res_0 = hn::ZipLower(int64_tag, filt_10, filt_11);
-  res_1 = hn::ZipLower(int64_tag, filt_12, filt_13);
-  res_2 = hn::ZipUpper(int64_tag, filt_10, filt_11);
-  res_3 = hn::ZipUpper(int64_tag, filt_12, filt_13);
+  res_0 = hn::ZipLower(int64xN_tag, filt_10, filt_11);
+  res_1 = hn::ZipLower(int64xN_tag, filt_12, filt_13);
+  res_2 = hn::ZipUpper(int64xN_tag, filt_10, filt_11);
+  res_3 = hn::ZipUpper(int64xN_tag, filt_12, filt_13);
 
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
-      int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveLower(int64xN_tag, res_0, res_1)),
+      int16xN_tag, coeffs + 4 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
-      int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveUpper(int64xN_tag, res_0, res_1)),
+      int16xN_tag, coeffs + 5 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
-      int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveLower(int64xN_tag, res_2, res_3)),
+      int16xN_tag, coeffs + 6 * hn::MaxLanes(int16xN_tag));
   hn::Store(
-      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
-      int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
+      hn::BitCast(int16xN_tag, hn::InterleaveUpper(int64xN_tag, res_2, res_3)),
+      int16xN_tag, coeffs + 7 * hn::MaxLanes(int16xN_tag));
 }
 
 HWY_ATTR inline void PrepareVerticalFilterCoeffsGamma0(
     int gamma, int delta, int sy, int16_t *HWY_RESTRICT coeffs) {
   (void)gamma;
   auto filt_0 = LoadAV1WarpedFilterLower(sy);
-  if constexpr (int16_tag.MaxBlocks() >= 2) {
+  if constexpr (int16xN_tag.MaxBlocks() >= 2) {
     filt_0 = LoadAV1WarpedFilterUpper<1>(sy + delta, filt_0);
   }
-  if constexpr (int16_tag.MaxBlocks() >= 3) {
+  if constexpr (int16xN_tag.MaxBlocks() >= 3) {
     filt_0 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta, filt_0);
     filt_0 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta, filt_0);
   }
-  auto res_0 = hn::BitCast(int32_tag, filt_0);
+  auto res_0 = hn::BitCast(int32xN_tag, filt_0);
 
-  auto broadcast_0 = hn::BitCast(int16_tag, hn::Broadcast<0>(res_0));
-  auto broadcast_1 = hn::BitCast(int16_tag, hn::Broadcast<1>(res_0));
-  auto broadcast_2 = hn::BitCast(int16_tag, hn::Broadcast<2>(res_0));
-  auto broadcast_3 = hn::BitCast(int16_tag, hn::Broadcast<3>(res_0));
+  auto broadcast_0 = hn::BitCast(int16xN_tag, hn::Broadcast<0>(res_0));
+  auto broadcast_1 = hn::BitCast(int16xN_tag, hn::Broadcast<1>(res_0));
+  auto broadcast_2 = hn::BitCast(int16xN_tag, hn::Broadcast<2>(res_0));
+  auto broadcast_3 = hn::BitCast(int16xN_tag, hn::Broadcast<3>(res_0));
 
-  hn::Store(broadcast_0, int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
-  hn::Store(broadcast_1, int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
-  hn::Store(broadcast_2, int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
-  hn::Store(broadcast_3, int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));
-  hn::Store(broadcast_0, int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
-  hn::Store(broadcast_1, int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
-  hn::Store(broadcast_2, int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
-  hn::Store(broadcast_3, int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
+  hn::Store(broadcast_0, int16xN_tag, coeffs + 0 * hn::MaxLanes(int16xN_tag));
+  hn::Store(broadcast_1, int16xN_tag, coeffs + 1 * hn::MaxLanes(int16xN_tag));
+  hn::Store(broadcast_2, int16xN_tag, coeffs + 2 * hn::MaxLanes(int16xN_tag));
+  hn::Store(broadcast_3, int16xN_tag, coeffs + 3 * hn::MaxLanes(int16xN_tag));
+  hn::Store(broadcast_0, int16xN_tag, coeffs + 4 * hn::MaxLanes(int16xN_tag));
+  hn::Store(broadcast_1, int16xN_tag, coeffs + 5 * hn::MaxLanes(int16xN_tag));
+  hn::Store(broadcast_2, int16xN_tag, coeffs + 6 * hn::MaxLanes(int16xN_tag));
+  hn::Store(broadcast_3, int16xN_tag, coeffs + 7 * hn::MaxLanes(int16xN_tag));
 }
 
 HWY_ATTR inline void FilterPixelsVertical(
     int16_t *HWY_RESTRICT horz_out, int16_t *HWY_RESTRICT src_lo,
     int16_t *HWY_RESTRICT src_hi, int16_t *HWY_RESTRICT coeffs,
     IVec32 &HWY_RESTRICT res_lo, IVec32 &HWY_RESTRICT res_hi, int row) {
-  if constexpr (int16_tag.MaxBlocks() >= 3) {
+  if constexpr (int16xN_tag.MaxBlocks() >= 3) {
     const auto horz_out_4 =
-        hn::Load(int16_tag, horz_out + (row + 4) * hn::MaxLanes(int16x8_tag));
-    const auto horz_out_5 =
-        hn::LoadU(int16_tag, horz_out + (row + 5) * hn::MaxLanes(int16x8_tag));
-    const auto horz_out_6 =
-        hn::LoadU(int16_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
-    const auto horz_out_7 =
-        hn::LoadU(int16_tag, horz_out + (row + 7) * hn::MaxLanes(int16x8_tag));
+        hn::Load(int16xN_tag, horz_out + (row + 4) * hn::MaxLanes(int16x8_tag));
+    const auto horz_out_5 = hn::LoadU(
+        int16xN_tag, horz_out + (row + 5) * hn::MaxLanes(int16x8_tag));
+    const auto horz_out_6 = hn::LoadU(
+        int16xN_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
+    const auto horz_out_7 = hn::LoadU(
+        int16xN_tag, horz_out + (row + 7) * hn::MaxLanes(int16x8_tag));
     const auto src_lo_2 =
-        hn::InterleaveLower(int16_tag, horz_out_4, horz_out_5);
+        hn::InterleaveLower(int16xN_tag, horz_out_4, horz_out_5);
     const auto src_hi_2 =
-        hn::InterleaveUpper(int16_tag, horz_out_4, horz_out_5);
+        hn::InterleaveUpper(int16xN_tag, horz_out_4, horz_out_5);
     const auto src_lo_3 =
-        hn::InterleaveLower(int16_tag, horz_out_6, horz_out_7);
+        hn::InterleaveLower(int16xN_tag, horz_out_6, horz_out_7);
     const auto src_hi_3 =
-        hn::InterleaveUpper(int16_tag, horz_out_6, horz_out_7);
-    hn::Store(src_lo_2, int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
-    hn::Store(src_hi_2, int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
-    hn::Store(src_lo_3, int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
-    hn::Store(src_hi_3, int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));
-  } else if constexpr (int16_tag.MaxBlocks() == 2) {
+        hn::InterleaveUpper(int16xN_tag, horz_out_6, horz_out_7);
+    hn::Store(src_lo_2, int16xN_tag, src_lo + 2 * hn::MaxLanes(int16xN_tag));
+    hn::Store(src_hi_2, int16xN_tag, src_hi + 2 * hn::MaxLanes(int16xN_tag));
+    hn::Store(src_lo_3, int16xN_tag, src_lo + 3 * hn::MaxLanes(int16xN_tag));
+    hn::Store(src_hi_3, int16xN_tag, src_hi + 3 * hn::MaxLanes(int16xN_tag));
+  } else if constexpr (int16xN_tag.MaxBlocks() == 2) {
     const auto horz_out_6 =
-        hn::Load(int16_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
+        hn::Load(int16xN_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
     const auto horz_out_8 =
-        hn::Load(int16_tag, horz_out + (row + 8) * hn::MaxLanes(int16x8_tag));
+        hn::Load(int16xN_tag, horz_out + (row + 8) * hn::MaxLanes(int16x8_tag));
     const auto horz_out_7 =
-        hn::ConcatLowerUpper(int16_tag, horz_out_8, horz_out_6);
+        hn::ConcatLowerUpper(int16xN_tag, horz_out_8, horz_out_6);
     const auto src_lo_3 =
-        hn::InterleaveLower(int16_tag, horz_out_6, horz_out_7);
+        hn::InterleaveLower(int16xN_tag, horz_out_6, horz_out_7);
     const auto src_hi_3 =
-        hn::InterleaveUpper(int16_tag, horz_out_6, horz_out_7);
-    hn::Store(src_lo_3, int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
-    hn::Store(src_hi_3, int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));
-  } else if constexpr (int16_tag.MaxBlocks() == 1) {
+        hn::InterleaveUpper(int16xN_tag, horz_out_6, horz_out_7);
+    hn::Store(src_lo_3, int16xN_tag, src_lo + 3 * hn::MaxLanes(int16xN_tag));
+    hn::Store(src_hi_3, int16xN_tag, src_hi + 3 * hn::MaxLanes(int16xN_tag));
+  } else if constexpr (int16xN_tag.MaxBlocks() == 1) {
     const auto horz_out_6 =
         hn::Load(int16x8_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
     const auto horz_out_7 =
@@ -687,66 +761,66 @@
   }
 
   const auto coeff_0 =
-      hn::Load(int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, coeffs + 0 * hn::MaxLanes(int16xN_tag));
   const auto coeff_1 =
-      hn::Load(int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, coeffs + 1 * hn::MaxLanes(int16xN_tag));
   const auto coeff_2 =
-      hn::Load(int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, coeffs + 2 * hn::MaxLanes(int16xN_tag));
   const auto coeff_3 =
-      hn::Load(int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, coeffs + 3 * hn::MaxLanes(int16xN_tag));
   const auto coeff_4 =
-      hn::Load(int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, coeffs + 4 * hn::MaxLanes(int16xN_tag));
   const auto coeff_5 =
-      hn::Load(int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, coeffs + 5 * hn::MaxLanes(int16xN_tag));
   const auto coeff_6 =
-      hn::Load(int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, coeffs + 6 * hn::MaxLanes(int16xN_tag));
   const auto coeff_7 =
-      hn::Load(int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, coeffs + 7 * hn::MaxLanes(int16xN_tag));
 
   const auto src_lo_0 =
-      hn::Load(int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, src_lo + 0 * hn::MaxLanes(int16xN_tag));
   const auto src_lo_1 =
-      hn::Load(int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, src_lo + 1 * hn::MaxLanes(int16xN_tag));
   const auto src_lo_2 =
-      hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, src_lo + 2 * hn::MaxLanes(int16xN_tag));
   const auto src_lo_3 =
-      hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, src_lo + 3 * hn::MaxLanes(int16xN_tag));
   const auto src_hi_0 =
-      hn::Load(int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, src_hi + 0 * hn::MaxLanes(int16xN_tag));
   const auto src_hi_1 =
-      hn::Load(int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, src_hi + 1 * hn::MaxLanes(int16xN_tag));
   const auto src_hi_2 =
-      hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, src_hi + 2 * hn::MaxLanes(int16xN_tag));
   const auto src_hi_3 =
-      hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));
+      hn::Load(int16xN_tag, src_hi + 3 * hn::MaxLanes(int16xN_tag));
 
-  auto even_sum0 = hn::Zero(int32_tag);
-  auto even_sum1 = hn::Zero(int32_tag);
-  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_0, coeff_0,
+  auto even_sum0 = hn::Zero(int32xN_tag);
+  auto even_sum1 = hn::Zero(int32xN_tag);
+  even_sum0 = hn::ReorderWidenMulAccumulate(int32xN_tag, src_lo_0, coeff_0,
                                             even_sum0, even_sum1);
-  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_1, coeff_1,
+  even_sum0 = hn::ReorderWidenMulAccumulate(int32xN_tag, src_lo_1, coeff_1,
                                             even_sum0, even_sum1);
-  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_2, coeff_2,
+  even_sum0 = hn::ReorderWidenMulAccumulate(int32xN_tag, src_lo_2, coeff_2,
                                             even_sum0, even_sum1);
-  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_3, coeff_3,
+  even_sum0 = hn::ReorderWidenMulAccumulate(int32xN_tag, src_lo_3, coeff_3,
                                             even_sum0, even_sum1);
   auto res_even = hn::RearrangeToOddPlusEven(even_sum0, even_sum1);
 
-  auto odd_sum0 = hn::Zero(int32_tag);
-  auto odd_sum1 = hn::Zero(int32_tag);
-  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_0, coeff_4,
+  auto odd_sum0 = hn::Zero(int32xN_tag);
+  auto odd_sum1 = hn::Zero(int32xN_tag);
+  odd_sum0 = hn::ReorderWidenMulAccumulate(int32xN_tag, src_hi_0, coeff_4,
                                            odd_sum0, odd_sum1);
-  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_1, coeff_5,
+  odd_sum0 = hn::ReorderWidenMulAccumulate(int32xN_tag, src_hi_1, coeff_5,
                                            odd_sum0, odd_sum1);
-  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_2, coeff_6,
+  odd_sum0 = hn::ReorderWidenMulAccumulate(int32xN_tag, src_hi_2, coeff_6,
                                            odd_sum0, odd_sum1);
-  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_3, coeff_7,
+  odd_sum0 = hn::ReorderWidenMulAccumulate(int32xN_tag, src_hi_3, coeff_7,
                                            odd_sum0, odd_sum1);
   auto res_odd = hn::RearrangeToOddPlusEven(odd_sum0, odd_sum1);
 
   // Rearrange pixels back into the order 0 ... 7
-  res_lo = hn::InterleaveLower(int32_tag, res_even, res_odd);
-  res_hi = hn::InterleaveUpper(int32_tag, res_even, res_odd);
+  res_lo = hn::InterleaveLower(int32xN_tag, res_even, res_odd);
+  res_hi = hn::InterleaveUpper(int32xN_tag, res_even, res_odd);
 }
 
 template <typename DS, typename DR, typename A, typename B, typename C>
@@ -777,7 +851,7 @@
     uint8_t *HWY_RESTRICT pred, ConvolveParams *HWY_RESTRICT conv_params, int i,
     int j, int k, const int reduce_bits_vert, int p_stride, int p_width,
     const int round_bits) {
-  constexpr int kNumRows = uint16_tag.MaxBlocks();
+  constexpr int kNumRows = uint16xN_tag.MaxBlocks();
   if (conv_params->is_compound) {
     uint16_t *HWY_RESTRICT pointers[kNumRows];
     for (int row = 0; row < kNumRows; ++row) {
@@ -788,10 +862,10 @@
     res_lo =
         hn::ShiftRightSame(hn::Add(res_lo, res_add_const), reduce_bits_vert);
 
-    const auto temp_lo_16 = hn::ReorderDemote2To(uint16_tag, res_lo, res_lo);
+    const auto temp_lo_16 = hn::ReorderDemote2To(uint16xN_tag, res_lo, res_lo);
     if (conv_params->do_average) {
       auto p_16 =
-          hn::ResizeBitCast(uint16_tag, hn::Load(uint16x4_tag, pointers[0]));
+          hn::ResizeBitCast(uint16xN_tag, hn::Load(uint16x4_tag, pointers[0]));
       if constexpr (kNumRows >= 2) {
         p_16 = hn::InsertBlock<1>(
             p_16, hn::ResizeBitCast(uint16x8_tag,
@@ -805,25 +879,26 @@
             p_16, hn::ResizeBitCast(uint16x8_tag,
                                     hn::Load(uint16x4_tag, pointers[3])));
       }
-      auto res_lo_16 = hn::Undefined(int16_tag);
+      auto res_lo_16 = hn::Undefined(int16xN_tag);
       if (conv_params->use_dist_wtd_comp_avg) {
         const auto p_16_lo =
-            hn::BitCast(int16_tag, hn::InterleaveLower(p_16, temp_lo_16));
-        const auto wt_res_lo = hn::WidenMulPairwiseAdd(int32_tag, p_16_lo, wt);
+            hn::BitCast(int16xN_tag, hn::InterleaveLower(p_16, temp_lo_16));
+        const auto wt_res_lo =
+            hn::WidenMulPairwiseAdd(int32xN_tag, p_16_lo, wt);
         const auto shifted_32 = hn::ShiftRight<DIST_PRECISION_BITS>(wt_res_lo);
         res_lo_16 = hn::BitCast(
-            int16_tag,
-            hn::ReorderDemote2To(uint16_tag, shifted_32, shifted_32));
+            int16xN_tag,
+            hn::ReorderDemote2To(uint16xN_tag, shifted_32, shifted_32));
       } else {
         res_lo_16 = hn::ShiftRight<1>(
-            hn::BitCast(int16_tag, hn::Add(p_16, temp_lo_16)));
+            hn::BitCast(int16xN_tag, hn::Add(p_16, temp_lo_16)));
       }
       res_lo_16 = hn::Add(res_lo_16, res_sub_const);
       res_lo_16 =
           hn::ShiftRightSame(hn::Add(res_lo_16, round_bits_const), round_bits);
       const auto res_8_lo =
-          hn::ReorderDemote2To(uint8_tag, res_lo_16, res_lo_16);
-      StoreRows(uint8x4_tag, uint8_tag, res_8_lo, p_stride, i + k, j, pred);
+          hn::ReorderDemote2To(uint8xN_tag, res_lo_16, res_lo_16);
+      StoreRows(uint8x4_tag, uint8xN_tag, res_8_lo, p_stride, i + k, j, pred);
     } else {
       hn::Store(
           hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<0>(temp_lo_16)),
@@ -850,10 +925,11 @@
       }
       res_hi =
           hn::ShiftRightSame(hn::Add(res_hi, res_add_const), reduce_bits_vert);
-      const auto temp_hi_16 = hn::ReorderDemote2To(uint16_tag, res_hi, res_hi);
+      const auto temp_hi_16 =
+          hn::ReorderDemote2To(uint16xN_tag, res_hi, res_hi);
       if (conv_params->do_average) {
-        auto p4_16 =
-            hn::ResizeBitCast(uint16_tag, hn::Load(uint16x4_tag, pointers4[0]));
+        auto p4_16 = hn::ResizeBitCast(uint16xN_tag,
+                                       hn::Load(uint16x4_tag, pointers4[0]));
         if constexpr (kNumRows >= 2) {
           p4_16 = hn::InsertBlock<1>(
               p4_16, hn::ResizeBitCast(uint16x8_tag,
@@ -868,27 +944,27 @@
                                        hn::Load(uint16x4_tag, pointers4[3])));
         }
 
-        auto res_hi_16 = hn::Undefined(int16_tag);
+        auto res_hi_16 = hn::Undefined(int16xN_tag);
         if (conv_params->use_dist_wtd_comp_avg) {
           const auto p_16_hi =
-              hn::BitCast(int16_tag, hn::InterleaveLower(p4_16, temp_hi_16));
+              hn::BitCast(int16xN_tag, hn::InterleaveLower(p4_16, temp_hi_16));
           const auto wt_res_hi =
-              hn::WidenMulPairwiseAdd(int32_tag, p_16_hi, wt);
+              hn::WidenMulPairwiseAdd(int32xN_tag, p_16_hi, wt);
           const auto shifted_32 =
               hn::ShiftRight<DIST_PRECISION_BITS>(wt_res_hi);
           res_hi_16 = hn::BitCast(
-              int16_tag,
-              hn::ReorderDemote2To(uint16_tag, shifted_32, shifted_32));
+              int16xN_tag,
+              hn::ReorderDemote2To(uint16xN_tag, shifted_32, shifted_32));
         } else {
           res_hi_16 = hn::ShiftRight<1>(
-              hn::BitCast(int16_tag, hn::Add(p4_16, temp_hi_16)));
+              hn::BitCast(int16xN_tag, hn::Add(p4_16, temp_hi_16)));
         }
         res_hi_16 = hn::Add(res_hi_16, res_sub_const);
         res_hi_16 = hn::ShiftRightSame(hn::Add(res_hi_16, round_bits_const),
                                        round_bits);
         const auto res_8_hi =
-            hn::ReorderDemote2To(uint8_tag, res_hi_16, res_hi_16);
-        StoreRows(uint8x4_tag, uint8_tag, res_8_hi, p_stride, i + k, j + 4,
+            hn::ReorderDemote2To(uint8xN_tag, res_hi_16, res_hi_16);
+        StoreRows(uint8x4_tag, uint8xN_tag, res_8_hi, p_stride, i + k, j + 4,
                   pred);
       } else {
         hn::Store(hn::ResizeBitCast(uint16x4_tag, temp_hi_16), uint16x4_tag,
@@ -915,13 +991,14 @@
         hn::ShiftRightSame(hn::Add(res_hi, res_add_const), reduce_bits_vert);
 
     const auto res_16bit =
-        hn::ReorderDemote2To(int16_tag, res_lo_round, res_hi_round);
-    const auto res_8bit = hn::ReorderDemote2To(uint8_tag, res_16bit, res_16bit);
+        hn::ReorderDemote2To(int16xN_tag, res_lo_round, res_hi_round);
+    const auto res_8bit =
+        hn::ReorderDemote2To(uint8xN_tag, res_16bit, res_16bit);
     // Store, blending with 'pred' if needed
     if (p_width == 4) {
-      StoreRows(uint8x4_tag, uint8_tag, res_8bit, p_stride, i + k, j, pred);
+      StoreRows(uint8x4_tag, uint8xN_tag, res_8bit, p_stride, i + k, j, pred);
     } else {
-      StoreRows(uint8x8_tag, uint8_tag, res_8bit, p_stride, i + k, j, pred);
+      StoreRows(uint8x8_tag, uint8xN_tag, res_8bit, p_stride, i + k, j, pred);
     }
   }
 }
@@ -936,86 +1013,86 @@
     const int reduce_bits_vert, const IVec32 res_add_const,
     const int round_bits, const IVec16 res_sub_const,
     const IVec16 round_bits_const, const IVec16 wt) {
-  HWY_ALIGN int16_t src_lo[4 * hn::MaxLanes(int16_tag)];
-  HWY_ALIGN int16_t src_hi[4 * hn::MaxLanes(int16_tag)];
-  if constexpr (int16_tag.MaxBlocks() >= 3) {
+  HWY_ALIGN int16_t src_lo[4 * hn::MaxLanes(int16xN_tag)];
+  HWY_ALIGN int16_t src_hi[4 * hn::MaxLanes(int16xN_tag)];
+  if constexpr (int16xN_tag.MaxBlocks() >= 3) {
     const auto horz_out_0 =
-        hn::Load(int16_tag, horz_out + 0 * hn::MaxLanes(int16x8_tag));
+        hn::Load(int16xN_tag, horz_out + 0 * hn::MaxLanes(int16x8_tag));
     const auto horz_out_1 =
-        hn::LoadU(int16_tag, horz_out + 1 * hn::MaxLanes(int16x8_tag));
+        hn::LoadU(int16xN_tag, horz_out + 1 * hn::MaxLanes(int16x8_tag));
     const auto horz_out_2 =
-        hn::LoadU(int16_tag, horz_out + 2 * hn::MaxLanes(int16x8_tag));
+        hn::LoadU(int16xN_tag, horz_out + 2 * hn::MaxLanes(int16x8_tag));
     const auto horz_out_3 =
-        hn::LoadU(int16_tag, horz_out + 3 * hn::MaxLanes(int16x8_tag));
-    hn::Store(hn::InterleaveLower(int16_tag, horz_out_0, horz_out_1), int16_tag,
-              src_lo + 0 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_0, horz_out_1), int16_tag,
-              src_hi + 0 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveLower(int16_tag, horz_out_2, horz_out_3), int16_tag,
-              src_lo + 1 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_2, horz_out_3), int16_tag,
-              src_hi + 1 * hn::MaxLanes(int16_tag));
-  } else if constexpr (int16_tag.MaxBlocks() == 2) {
+        hn::LoadU(int16xN_tag, horz_out + 3 * hn::MaxLanes(int16x8_tag));
+    hn::Store(hn::InterleaveLower(int16xN_tag, horz_out_0, horz_out_1),
+              int16xN_tag, src_lo + 0 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveUpper(int16xN_tag, horz_out_0, horz_out_1),
+              int16xN_tag, src_hi + 0 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveLower(int16xN_tag, horz_out_2, horz_out_3),
+              int16xN_tag, src_lo + 1 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveUpper(int16xN_tag, horz_out_2, horz_out_3),
+              int16xN_tag, src_hi + 1 * hn::MaxLanes(int16xN_tag));
+  } else if constexpr (int16xN_tag.MaxBlocks() == 2) {
     const auto horz_out_0 =
-        hn::Load(int16_tag, horz_out + 0 * hn::MaxLanes(int16_tag));
+        hn::Load(int16xN_tag, horz_out + 0 * hn::MaxLanes(int16xN_tag));
     const auto horz_out_2 =
-        hn::Load(int16_tag, horz_out + 1 * hn::MaxLanes(int16_tag));
+        hn::Load(int16xN_tag, horz_out + 1 * hn::MaxLanes(int16xN_tag));
     const auto horz_out_4 =
-        hn::Load(int16_tag, horz_out + 2 * hn::MaxLanes(int16_tag));
+        hn::Load(int16xN_tag, horz_out + 2 * hn::MaxLanes(int16xN_tag));
     const auto horz_out_6 =
-        hn::Load(int16_tag, horz_out + 3 * hn::MaxLanes(int16_tag));
+        hn::Load(int16xN_tag, horz_out + 3 * hn::MaxLanes(int16xN_tag));
     const auto horz_out_1 =
-        hn::ConcatLowerUpper(int16_tag, horz_out_2, horz_out_0);
+        hn::ConcatLowerUpper(int16xN_tag, horz_out_2, horz_out_0);
     const auto horz_out_3 =
-        hn::ConcatLowerUpper(int16_tag, horz_out_4, horz_out_2);
+        hn::ConcatLowerUpper(int16xN_tag, horz_out_4, horz_out_2);
     const auto horz_out_5 =
-        hn::ConcatLowerUpper(int16_tag, horz_out_6, horz_out_4);
-    hn::Store(hn::InterleaveLower(int16_tag, horz_out_0, horz_out_1), int16_tag,
-              src_lo + 0 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_0, horz_out_1), int16_tag,
-              src_hi + 0 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveLower(int16_tag, horz_out_2, horz_out_3), int16_tag,
-              src_lo + 1 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_2, horz_out_3), int16_tag,
-              src_hi + 1 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveLower(int16_tag, horz_out_4, horz_out_5), int16_tag,
-              src_lo + 2 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_4, horz_out_5), int16_tag,
-              src_hi + 2 * hn::MaxLanes(int16_tag));
+        hn::ConcatLowerUpper(int16xN_tag, horz_out_6, horz_out_4);
+    hn::Store(hn::InterleaveLower(int16xN_tag, horz_out_0, horz_out_1),
+              int16xN_tag, src_lo + 0 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveUpper(int16xN_tag, horz_out_0, horz_out_1),
+              int16xN_tag, src_hi + 0 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveLower(int16xN_tag, horz_out_2, horz_out_3),
+              int16xN_tag, src_lo + 1 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveUpper(int16xN_tag, horz_out_2, horz_out_3),
+              int16xN_tag, src_hi + 1 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveLower(int16xN_tag, horz_out_4, horz_out_5),
+              int16xN_tag, src_lo + 2 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveUpper(int16xN_tag, horz_out_4, horz_out_5),
+              int16xN_tag, src_hi + 2 * hn::MaxLanes(int16xN_tag));
   } else {
     const auto horz_out_0 =
-        hn::Load(int16_tag, horz_out + 0 * hn::MaxLanes(int16_tag));
+        hn::Load(int16xN_tag, horz_out + 0 * hn::MaxLanes(int16xN_tag));
     const auto horz_out_1 =
-        hn::Load(int16_tag, horz_out + 1 * hn::MaxLanes(int16_tag));
+        hn::Load(int16xN_tag, horz_out + 1 * hn::MaxLanes(int16xN_tag));
     const auto horz_out_2 =
-        hn::Load(int16_tag, horz_out + 2 * hn::MaxLanes(int16_tag));
+        hn::Load(int16xN_tag, horz_out + 2 * hn::MaxLanes(int16xN_tag));
     const auto horz_out_3 =
-        hn::Load(int16_tag, horz_out + 3 * hn::MaxLanes(int16_tag));
+        hn::Load(int16xN_tag, horz_out + 3 * hn::MaxLanes(int16xN_tag));
     const auto horz_out_4 =
-        hn::Load(int16_tag, horz_out + 4 * hn::MaxLanes(int16_tag));
+        hn::Load(int16xN_tag, horz_out + 4 * hn::MaxLanes(int16xN_tag));
     const auto horz_out_5 =
-        hn::Load(int16_tag, horz_out + 5 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveLower(int16_tag, horz_out_0, horz_out_1), int16_tag,
-              src_lo + 0 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_0, horz_out_1), int16_tag,
-              src_hi + 0 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveLower(int16_tag, horz_out_2, horz_out_3), int16_tag,
-              src_lo + 1 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_2, horz_out_3), int16_tag,
-              src_hi + 1 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveLower(int16_tag, horz_out_4, horz_out_5), int16_tag,
-              src_lo + 2 * hn::MaxLanes(int16_tag));
-    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_4, horz_out_5), int16_tag,
-              src_hi + 2 * hn::MaxLanes(int16_tag));
+        hn::Load(int16xN_tag, horz_out + 5 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveLower(int16xN_tag, horz_out_0, horz_out_1),
+              int16xN_tag, src_lo + 0 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveUpper(int16xN_tag, horz_out_0, horz_out_1),
+              int16xN_tag, src_hi + 0 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveLower(int16xN_tag, horz_out_2, horz_out_3),
+              int16xN_tag, src_lo + 1 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveUpper(int16xN_tag, horz_out_2, horz_out_3),
+              int16xN_tag, src_hi + 1 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveLower(int16xN_tag, horz_out_4, horz_out_5),
+              int16xN_tag, src_lo + 2 * hn::MaxLanes(int16xN_tag));
+    hn::Store(hn::InterleaveUpper(int16xN_tag, horz_out_4, horz_out_5),
+              int16xN_tag, src_hi + 2 * hn::MaxLanes(int16xN_tag));
   }
 
-  HWY_ALIGN int16_t coeffs[8 * hn::MaxLanes(int16_tag)];
+  HWY_ALIGN int16_t coeffs[8 * hn::MaxLanes(int16xN_tag)];
   if constexpr (!InnerCoeffUpdate) {
     PrepareCoeffs(gamma, delta, sy4, coeffs);
   }
 
   for (int k = -4; k < AOMMIN(4, p_height - i - 4);
-       k += static_cast<int>(int16_tag.MaxBlocks())) {
+       k += static_cast<int>(int16xN_tag.MaxBlocks())) {
     if constexpr (InnerCoeffUpdate) {
       int sy = sy4 + delta * (k + 4);
       PrepareCoeffs(gamma, delta, sy, coeffs);
@@ -1028,63 +1105,69 @@
                               round_bits_const, pred, conv_params, i, j, k + 4,
                               reduce_bits_vert, p_stride, p_width, round_bits);
 
-    if constexpr (int16_tag.MaxBlocks() >= 3) {
-      hn::Store(hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag)),
-                int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
-      hn::Store(hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag)),
-                int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
-      hn::Store(hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag)),
-                int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
-      hn::Store(hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag)),
-                int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
-    } else if constexpr (int16_tag.MaxBlocks() == 2) {
-      hn::Store(hn::Load(int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag)),
-                int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
-      hn::Store(hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag)),
-                int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
-      hn::Store(hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag)),
-                int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
-      hn::Store(hn::Load(int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag)),
-                int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
-      hn::Store(hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag)),
-                int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
-      hn::Store(hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag)),
-                int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
-    } else if constexpr (int16_tag.MaxBlocks() == 1) {
+    if constexpr (int16xN_tag.MaxBlocks() >= 3) {
+      hn::Store(hn::Load(int16xN_tag, src_lo + 2 * hn::MaxLanes(int16xN_tag)),
+                int16xN_tag, src_lo + 0 * hn::MaxLanes(int16xN_tag));
+      hn::Store(hn::Load(int16xN_tag, src_lo + 3 * hn::MaxLanes(int16xN_tag)),
+                int16xN_tag, src_lo + 1 * hn::MaxLanes(int16xN_tag));
+      hn::Store(hn::Load(int16xN_tag, src_hi + 2 * hn::MaxLanes(int16xN_tag)),
+                int16xN_tag, src_hi + 0 * hn::MaxLanes(int16xN_tag));
+      hn::Store(hn::Load(int16xN_tag, src_hi + 3 * hn::MaxLanes(int16xN_tag)),
+                int16xN_tag, src_hi + 1 * hn::MaxLanes(int16xN_tag));
+    } else if constexpr (int16xN_tag.MaxBlocks() == 2) {
+      hn::Store(hn::Load(int16xN_tag, src_lo + 1 * hn::MaxLanes(int16xN_tag)),
+                int16xN_tag, src_lo + 0 * hn::MaxLanes(int16xN_tag));
+      hn::Store(hn::Load(int16xN_tag, src_lo + 2 * hn::MaxLanes(int16xN_tag)),
+                int16xN_tag, src_lo + 1 * hn::MaxLanes(int16xN_tag));
+      hn::Store(hn::Load(int16xN_tag, src_lo + 3 * hn::MaxLanes(int16xN_tag)),
+                int16xN_tag, src_lo + 2 * hn::MaxLanes(int16xN_tag));
+      hn::Store(hn::Load(int16xN_tag, src_hi + 1 * hn::MaxLanes(int16xN_tag)),
+                int16xN_tag, src_hi + 0 * hn::MaxLanes(int16xN_tag));
+      hn::Store(hn::Load(int16xN_tag, src_hi + 2 * hn::MaxLanes(int16xN_tag)),
+                int16xN_tag, src_hi + 1 * hn::MaxLanes(int16xN_tag));
+      hn::Store(hn::Load(int16xN_tag, src_hi + 3 * hn::MaxLanes(int16xN_tag)),
+                int16xN_tag, src_hi + 2 * hn::MaxLanes(int16xN_tag));
+    } else if constexpr (int16xN_tag.MaxBlocks() == 1) {
       const auto src_lo_0 =
-          hn::Load(int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
+          hn::Load(int16xN_tag, src_lo + 0 * hn::MaxLanes(int16xN_tag));
       const auto src_lo_1 =
-          hn::Load(int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
+          hn::Load(int16xN_tag, src_lo + 1 * hn::MaxLanes(int16xN_tag));
       const auto src_lo_2 =
-          hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
+          hn::Load(int16xN_tag, src_lo + 2 * hn::MaxLanes(int16xN_tag));
       const auto src_lo_3 =
-          hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
+          hn::Load(int16xN_tag, src_lo + 3 * hn::MaxLanes(int16xN_tag));
       const auto src_lo_0_new = hn::InterleaveEven(
-          hn::ShiftRightLanes<1>(int16_tag, src_lo_0), src_lo_1);
+          hn::ShiftRightLanes<1>(int16xN_tag, src_lo_0), src_lo_1);
       const auto src_lo_1_new = hn::InterleaveEven(
-          hn::ShiftRightLanes<1>(int16_tag, src_lo_1), src_lo_2);
+          hn::ShiftRightLanes<1>(int16xN_tag, src_lo_1), src_lo_2);
       const auto src_lo_2_new = hn::InterleaveEven(
-          hn::ShiftRightLanes<1>(int16_tag, src_lo_2), src_lo_3);
-      hn::Store(src_lo_0_new, int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
-      hn::Store(src_lo_1_new, int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
-      hn::Store(src_lo_2_new, int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
+          hn::ShiftRightLanes<1>(int16xN_tag, src_lo_2), src_lo_3);
+      hn::Store(src_lo_0_new, int16xN_tag,
+                src_lo + 0 * hn::MaxLanes(int16xN_tag));
+      hn::Store(src_lo_1_new, int16xN_tag,
+                src_lo + 1 * hn::MaxLanes(int16xN_tag));
+      hn::Store(src_lo_2_new, int16xN_tag,
+                src_lo + 2 * hn::MaxLanes(int16xN_tag));
       const auto src_hi_0 =
-          hn::Load(int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
+          hn::Load(int16xN_tag, src_hi + 0 * hn::MaxLanes(int16xN_tag));
       const auto src_hi_1 =
-          hn::Load(int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
+          hn::Load(int16xN_tag, src_hi + 1 * hn::MaxLanes(int16xN_tag));
       const auto src_hi_2 =
-          hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
+          hn::Load(int16xN_tag, src_hi + 2 * hn::MaxLanes(int16xN_tag));
       const auto src_hi_3 =
-          hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));
+          hn::Load(int16xN_tag, src_hi + 3 * hn::MaxLanes(int16xN_tag));
       const auto src_hi_0_new = hn::InterleaveEven(
-          hn::ShiftRightLanes<1>(int16_tag, src_hi_0), src_hi_1);
+          hn::ShiftRightLanes<1>(int16xN_tag, src_hi_0), src_hi_1);
       const auto src_hi_1_new = hn::InterleaveEven(
-          hn::ShiftRightLanes<1>(int16_tag, src_hi_1), src_hi_2);
+          hn::ShiftRightLanes<1>(int16xN_tag, src_hi_1), src_hi_2);
       const auto src_hi_2_new = hn::InterleaveEven(
-          hn::ShiftRightLanes<1>(int16_tag, src_hi_2), src_hi_3);
-      hn::Store(src_hi_0_new, int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
-      hn::Store(src_hi_1_new, int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
-      hn::Store(src_hi_2_new, int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
+          hn::ShiftRightLanes<1>(int16xN_tag, src_hi_2), src_hi_3);
+      hn::Store(src_hi_0_new, int16xN_tag,
+                src_hi + 0 * hn::MaxLanes(int16xN_tag));
+      hn::Store(src_hi_1_new, int16xN_tag,
+                src_hi + 1 * hn::MaxLanes(int16xN_tag));
+      hn::Store(src_hi_2_new, int16xN_tag,
+                src_hi + 2 * hn::MaxLanes(int16xN_tag));
     }
   }
 }
@@ -1123,23 +1206,19 @@
     int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta, int p_height,
     int height, int i, const IVec16 round_const, const int reduce_bits_horiz) {
   if (alpha == 0 && beta == 0)
-    WarpHorizontalFilterTemplate<false,
-                                 PrepareHorizontalFilterCoefficientsAlpha0>(
+    WarpHorizontalFilterTemplate<false, HorizontalFilterCoeffs::kAlpha0>(
         ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
         round_const, reduce_bits_horiz);
   else if (alpha == 0 && beta != 0)
-    WarpHorizontalFilterTemplate<true,
-                                 PrepareHorizontalFilterCoefficientsAlpha0>(
+    WarpHorizontalFilterTemplate<true, HorizontalFilterCoeffs::kAlpha0>(
         ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
         round_const, reduce_bits_horiz);
   else if (alpha != 0 && beta == 0)
-    WarpHorizontalFilterTemplate<false,
-                                 PrepareHorizontalFilterCoefficientsBeta0>(
+    WarpHorizontalFilterTemplate<false, HorizontalFilterCoeffs::kBeta0>(
         ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
         round_const, reduce_bits_horiz);
   else
-    WarpHorizontalFilterTemplate<true, PrepareHorizontalFilterCoefficients,
-                                 PrepareLastHorizontalFilterCoefficients>(
+    WarpHorizontalFilterTemplate<true, HorizontalFilterCoeffs::kDefault>(
         ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
         round_const, reduce_bits_horiz);
 }
@@ -1180,17 +1259,17 @@
     int i, int iy4, int16_t const4, int16_t const5, int offset,
     int16_t *HWY_RESTRICT horz_out) {
   int k = -7, iy;
-  if constexpr (int16_tag.MaxBlocks() >= 3) {
-    k = WarpHorizontalFilterOutOfBoundsSetLoop(int16_tag, ref, height, stride,
+  if constexpr (int16xN_tag.MaxBlocks() >= 3) {
+    k = WarpHorizontalFilterOutOfBoundsSetLoop(int16xN_tag, ref, height, stride,
                                                p_height, i, iy4, const4, const5,
                                                offset, k, horz_out);
   }
-  if constexpr (int16_tag.MaxBlocks() >= 2) {
+  if constexpr (int16xN_tag.MaxBlocks() >= 2) {
     k = WarpHorizontalFilterOutOfBoundsSetLoop(int16x16_tag, ref, height,
                                                stride, p_height, i, iy4, const4,
                                                const5, offset, k, horz_out);
   }
-  if constexpr (int16_tag.MaxBlocks() == 1) {
+  if constexpr (int16xN_tag.MaxBlocks() == 1) {
     k = WarpHorizontalFilterOutOfBoundsSetLoop(int16x8_tag, ref, height, stride,
                                                p_height, i, iy4, const4, const5,
                                                offset, k, horz_out);
@@ -1236,22 +1315,22 @@
   const int out_of_boundary_left = -(ix4 - 6);
   const int out_of_boundary_right = (ix4 + 8) - width;
   int k = -7, iy, sx;
-  if constexpr (uint8_tag.MaxBlocks() >= 3) {
+  if constexpr (uint8xN_tag.MaxBlocks() >= 3) {
     k = WarpHorizontalFilterOutOfBoundsPadLoop(
-        uint8_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
-        round_const, reduce_bits_horiz, out_of_boundary_left,
+        uint8xN_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height,
+        i, round_const, reduce_bits_horiz, out_of_boundary_left,
         out_of_boundary_right, k, horz_out);
   }
-  if constexpr (uint8_tag.MaxBlocks() >= 2) {
+  if constexpr (uint8xN_tag.MaxBlocks() >= 2) {
     k = WarpHorizontalFilterOutOfBoundsPadLoop(
         uint8x32_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height,
         i, round_const, reduce_bits_horiz, out_of_boundary_left,
         out_of_boundary_right, k, horz_out);
   }
-  if constexpr (uint8_tag.MaxBlocks() == 1) {
+  if constexpr (uint8xN_tag.MaxBlocks() == 1) {
     k = WarpHorizontalFilterOutOfBoundsPadLoop(
-        uint8_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
-        round_const, reduce_bits_horiz, out_of_boundary_left,
+        uint8xN_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height,
+        i, round_const, reduce_bits_horiz, out_of_boundary_left,
         out_of_boundary_right, k, horz_out);
   }
   iy = iy4 + k;
@@ -1268,8 +1347,8 @@
     src = hn::TableLookupBytes(src, shuffle_reg_right);
   }
   sx = sx4 + beta * (k + 4);
-  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(int8_tag)];
-  PrepareLastHorizontalFilterCoefficients(alpha, beta, sx, coeff);
+  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(coeff_tag)];
+  PrepareLastHorizontalFilterCoefficients(int16xN_tag, alpha, beta, sx, coeff);
   FilterPixelsHorizontal(uint8x16_tag, src, horz_out, coeff, round_const,
                          reduce_bits_horiz, k + 7);
 }
@@ -1293,15 +1372,15 @@
 
   const int offset_bits_vert = bd + 2 * FILTER_BITS - reduce_bits_horiz;
   const auto reduce_bits_vert_const =
-      hn::Set(int32_tag, ((1 << reduce_bits_vert) >> 1));
-  const auto res_add_const = hn::Set(int32_tag, 1 << offset_bits_vert);
+      hn::Set(int32xN_tag, ((1 << reduce_bits_vert) >> 1));
+  const auto res_add_const = hn::Set(int32xN_tag, 1 << offset_bits_vert);
   const int round_bits =
       2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
   const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
   assert(IMPLIES(conv_params->do_average, conv_params->is_compound));
 
   const auto round_const = hn::Set(
-      int16_tag, (1 << offset_bits_horiz) + ((1 << reduce_bits_horiz) >> 1));
+      int16xN_tag, (1 << offset_bits_horiz) + ((1 << reduce_bits_horiz) >> 1));
 
   IVec16 res_sub_const, round_bits_const, wt;
   UnpackWeightsAndSetRoundConst(conv_params, round_bits, offset_bits,
@@ -1311,8 +1390,8 @@
   if (conv_params->is_compound == 1) {
     res_add_const_1 = hn::Add(reduce_bits_vert_const, res_add_const);
   } else {
-    res_add_const_1 = hn::Set(int32_tag, -(1 << (bd + reduce_bits_vert - 1)) +
-                                             ((1 << reduce_bits_vert) >> 1));
+    res_add_const_1 = hn::Set(int32xN_tag, -(1 << (bd + reduce_bits_vert - 1)) +
+                                               ((1 << reduce_bits_vert) >> 1));
   }
   const int32_t const1 = alpha * (-4) + beta * (-4) +
                          (1 << (WARPEDDIFF_PREC_BITS - 1)) +
@@ -1326,7 +1405,7 @@
 
   for (i = 0; i < p_height; i += 8) {
     for (j = 0; j < p_width; j += 8) {
-      HWY_ALIGN int16_t horz_out[8 * 16 + hn::MaxLanes(int16_tag)];
+      HWY_ALIGN int16_t horz_out[8 * 16 + hn::MaxLanes(int16xN_tag)];
       const int32_t src_x = (p_col + j + 4) << subsampling_x;
       const int32_t src_y = (p_row + i + 4) << subsampling_y;
       const int64_t dst_x =