K-means: NEON implementation.

Bug: b/217282899

Change-Id: I03520263f78c520c75f4d068031ebd4335476f90
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index dc0ea9f..9101979 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -37,6 +37,14 @@
 #endif
 }
 
+static INLINE int64_t horizontal_add_s64x2(const int64x2_t a) {
+#if defined(__aarch64__)
+  return vaddvq_s64(a);
+#else
+  return vgetq_lane_s64(a, 0) + vgetq_lane_s64(a, 1);
+#endif
+}
+
 static INLINE uint64_t horizontal_add_u64x2(const uint64x2_t a) {
 #if defined(__aarch64__)
   return vaddvq_u64(a);
diff --git a/av1/av1.cmake b/av1/av1.cmake
index 992b5c3..1d39c1a 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -358,6 +358,7 @@
             "${AOM_ROOT}/av1/encoder/arm/neon/av1_error_neon.c"
             "${AOM_ROOT}/av1/encoder/arm/neon/encodetxb_neon.c"
             "${AOM_ROOT}/av1/encoder/arm/neon/hybrid_fwd_txfm_neon.c"
+            "${AOM_ROOT}/av1/encoder/arm/neon/av1_k_means_neon.c"
             "${AOM_ROOT}/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c"
             "${AOM_ROOT}/av1/encoder/arm/neon/highbd_fwd_txfm_neon.c"
             "${AOM_ROOT}/av1/encoder/arm/neon/wedge_utils_neon.c"
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 73003a9..176c6f7 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -413,10 +413,10 @@
   add_proto qw/void av1_quantize_b/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan, const qm_val_t * qm_ptr, const qm_val_t * iqm_ptr, int log_scale";
 
   add_proto qw/void av1_calc_indices_dim1/, "const int16_t *data, const int16_t *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
-  specialize qw/av1_calc_indices_dim1 sse2 avx2/;
+  specialize qw/av1_calc_indices_dim1 sse2 avx2 neon/;
 
   add_proto qw/void av1_calc_indices_dim2/, "const int16_t *data, const int16_t *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
-  specialize qw/av1_calc_indices_dim2 sse2 avx2/;
+  specialize qw/av1_calc_indices_dim2 sse2 avx2 neon/;
 
   # ENCODEMB INVOKE
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
diff --git a/av1/encoder/arm/neon/av1_k_means_neon.c b/av1/encoder/arm/neon/av1_k_means_neon.c
new file mode 100644
index 0000000..d421f76
--- /dev/null
+++ b/av1/encoder/arm/neon/av1_k_means_neon.c
@@ -0,0 +1,114 @@
+/*
+ *  Copyright (c) 2023, Alliance for Open Media. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <arm_neon.h>
+
+#include "aom_dsp/arm/sum_neon.h"
+#include "config/aom_dsp_rtcd.h"
+
+static int32x4_t k_means_multiply_add_neon(const int16x8_t a) {
+  const int32x4_t l = vmull_s16(vget_low_s16(a), vget_low_s16(a));
+  const int32x4_t h = vmull_s16(vget_high_s16(a), vget_high_s16(a));
+#if defined(__aarch64__)
+  return vpaddq_s32(l, h);
+#else
+  const int32x2_t dl = vpadd_s32(vget_low_s32(l), vget_high_s32(l));
+  const int32x2_t dh = vpadd_s32(vget_low_s32(h), vget_high_s32(h));
+  return vcombine_s32(dl, dh);
+#endif
+}
+
+void av1_calc_indices_dim1_neon(const int16_t *data, const int16_t *centroids,
+                                uint8_t *indices, int64_t *total_dist, int n,
+                                int k) {
+  int64x2_t sum = vdupq_n_s64(0);
+  int16x8_t cents[PALETTE_MAX_SIZE];
+  for (int j = 0; j < k; ++j) {
+    cents[j] = vdupq_n_s16(centroids[j]);
+  }
+
+  for (int i = 0; i < n; i += 8) {
+    const int16x8_t in = vld1q_s16(data);
+    uint16x8_t ind = vdupq_n_u16(0);
+    // Compute the distance to the first centroid.
+    int16x8_t dist_min = vabdq_s16(in, cents[0]);
+
+    for (int j = 1; j < k; ++j) {
+      // Compute the distance to the centroid.
+      const int16x8_t dist = vabdq_s16(in, cents[j]);
+      // Compare to the minimal one.
+      const uint16x8_t cmp = vcgtq_s16(dist_min, dist);
+      dist_min = vminq_s16(dist_min, dist);
+      const uint16x8_t ind1 = vdupq_n_u16(j);
+      ind = vbslq_u16(cmp, ind1, ind);
+    }
+    if (total_dist) {
+      // Square, convert to 32 bit and add together.
+      const int32x4_t l =
+          vmull_s16(vget_low_s16(dist_min), vget_low_s16(dist_min));
+      const int32x4_t sum32_tmp =
+          vmlal_s16(l, vget_high_s16(dist_min), vget_high_s16(dist_min));
+      // Pairwise sum, convert to 64 bit and add to sum.
+      sum = vpadalq_s32(sum, sum32_tmp);
+    }
+    vst1_u8(indices, vmovn_u16(ind));
+    indices += 8;
+    data += 8;
+  }
+  if (total_dist) {
+    *total_dist = horizontal_add_s64x2(sum);
+  }
+}
+
+void av1_calc_indices_dim2_neon(const int16_t *data, const int16_t *centroids,
+                                uint8_t *indices, int64_t *total_dist, int n,
+                                int k) {
+  int64x2_t sum = vdupq_n_s64(0);
+  uint32x4_t ind[2];
+  int16x8_t cents[PALETTE_MAX_SIZE];
+  for (int j = 0; j < k; ++j) {
+    const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
+    const int16_t cxcy[8] = { cx, cy, cx, cy, cx, cy, cx, cy };
+    cents[j] = vld1q_s16(cxcy);
+  }
+
+  for (int i = 0; i < n; i += 8) {
+    for (int l = 0; l < 2; ++l) {
+      const int16x8_t in = vld1q_s16(data);
+      ind[l] = vdupq_n_u32(0);
+      // Compute the distance to the first centroid.
+      int16x8_t d1 = vsubq_s16(in, cents[0]);
+      int32x4_t dist_min = k_means_multiply_add_neon(d1);
+
+      for (int j = 1; j < k; ++j) {
+        // Compute the distance to the centroid.
+        d1 = vsubq_s16(in, cents[j]);
+        const int32x4_t dist = k_means_multiply_add_neon(d1);
+        // Compare to the minimal one.
+        const uint32x4_t cmp = vcgtq_s32(dist_min, dist);
+        dist_min = vminq_s32(dist_min, dist);
+        const uint32x4_t ind1 = vdupq_n_u32(j);
+        ind[l] = vbslq_u32(cmp, ind1, ind[l]);
+      }
+      if (total_dist) {
+        // Pairwise sum, convert to 64 bit and add to sum.
+        sum = vpadalq_s32(sum, dist_min);
+      }
+      data += 8;
+    }
+    // Cast to 8 bit and store.
+    vst1_u8(indices,
+            vmovn_u16(vcombine_u16(vmovn_u32(ind[0]), vmovn_u32(ind[1]))));
+    indices += 8;
+  }
+  if (total_dist) {
+    *total_dist = horizontal_add_s64x2(sum);
+  }
+}
diff --git a/test/av1_k_means_test.cc b/test/av1_k_means_test.cc
index 221dd10..99f0fba 100644
--- a/test/av1_k_means_test.cc
+++ b/test/av1_k_means_test.cc
@@ -259,7 +259,7 @@
   RunSpeedTest(GET_PARAM(0), GET_PARAM(1), 8);
 }
 
-#if HAVE_AVX2 || HAVE_SSE2
+#if HAVE_SSE2 || HAVE_AVX2 || HAVE_NEON
 const BLOCK_SIZE kValidBlockSize[] = { BLOCK_8X8,   BLOCK_8X16,  BLOCK_8X32,
                                        BLOCK_16X8,  BLOCK_16X16, BLOCK_16X32,
                                        BLOCK_32X8,  BLOCK_32X16, BLOCK_32X32,
@@ -267,6 +267,17 @@
                                        BLOCK_16X64, BLOCK_64X16 };
 #endif
 
+#if HAVE_SSE2
+INSTANTIATE_TEST_SUITE_P(
+    SSE2, AV1KmeansTest1,
+    ::testing::Combine(::testing::Values(&av1_calc_indices_dim1_sse2),
+                       ::testing::ValuesIn(kValidBlockSize)));
+INSTANTIATE_TEST_SUITE_P(
+    SSE2, AV1KmeansTest2,
+    ::testing::Combine(::testing::Values(&av1_calc_indices_dim2_sse2),
+                       ::testing::ValuesIn(kValidBlockSize)));
+#endif
+
 #if HAVE_AVX2
 INSTANTIATE_TEST_SUITE_P(
     AVX2, AV1KmeansTest1,
@@ -278,15 +289,14 @@
                        ::testing::ValuesIn(kValidBlockSize)));
 #endif
 
-#if HAVE_SSE2
-
+#if HAVE_NEON
 INSTANTIATE_TEST_SUITE_P(
-    SSE2, AV1KmeansTest1,
-    ::testing::Combine(::testing::Values(&av1_calc_indices_dim1_sse2),
+    NEON, AV1KmeansTest1,
+    ::testing::Combine(::testing::Values(&av1_calc_indices_dim1_neon),
                        ::testing::ValuesIn(kValidBlockSize)));
 INSTANTIATE_TEST_SUITE_P(
-    SSE2, AV1KmeansTest2,
-    ::testing::Combine(::testing::Values(&av1_calc_indices_dim2_sse2),
+    NEON, AV1KmeansTest2,
+    ::testing::Combine(::testing::Values(&av1_calc_indices_dim2_neon),
                        ::testing::ValuesIn(kValidBlockSize)));
 #endif