Add av1_calc_indices_dim1_avx2 for k_means

PERF analysis: C-code:1.49%  AVX2-code:0.28%

Speed test (AVX2 speed-up over the C version, k = 8 centroids):

BlockSize	Gain
8X8		4.47x
8X16		4.21x
8X32		4.65x
16X8		4.27x
16X16		4.62x
16X32		5.8x
32X8		4.67x
32X16		5.84x
32X32		7.64x
32X64		8.7x
64X32		8.61x
64X64		9.18x
16X64		7.62x
64X16		7.59x

Change-Id: I417bb9c9b676e4e5b71dbb718133a368c76d30c6
diff --git a/av1/av1.cmake b/av1/av1.cmake
index ba450fb..9187d20 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -415,6 +415,7 @@
             "${AOM_ROOT}/av1/encoder/x86/wedge_utils_avx2.c"
             "${AOM_ROOT}/av1/encoder/x86/encodetxb_avx2.c"
             "${AOM_ROOT}/av1/encoder/x86/rdopt_avx2.c"
+            "${AOM_ROOT}/av1/encoder/x86/av1_k_means_avx2.c"
             "${AOM_ROOT}/av1/encoder/x86/temporal_filter_avx2.c"
             "${AOM_ROOT}/av1/encoder/x86/highbd_temporal_filter_avx2.c"
             "${AOM_ROOT}/av1/encoder/x86/pickrst_avx2.c")
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 2d1da7f..2264b80 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -361,6 +361,9 @@
   }
   add_proto qw/void av1_quantize_b/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan, const qm_val_t * qm_ptr, const qm_val_t * iqm_ptr, int log_scale";
 
+  add_proto qw/void av1_calc_indices_dim1/, "const int *data, const int *centroids, uint8_t *indices, int n, int k";
+  specialize qw/av1_calc_indices_dim1 avx2/;
+
   # ENCODEMB INVOKE
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
     add_proto qw/int64_t av1_highbd_block_error/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz, int bd";
diff --git a/av1/encoder/k_means_template.h b/av1/encoder/k_means_template.h
index 9e526b8..1998a8a 100644
--- a/av1/encoder/k_means_template.h
+++ b/av1/encoder/k_means_template.h
@@ -95,7 +95,11 @@
   int pre_centroids[2 * PALETTE_MAX_SIZE];
   uint8_t pre_indices[MAX_SB_SQUARE];
 
+#if AV1_K_MEANS_DIM == 1  // 1-D indices go through the RTCD-dispatched entry.
+  av1_calc_indices_dim1(data, centroids, indices, n, k);
+#else
   RENAME(av1_calc_indices)(data, centroids, indices, n, k);
+#endif  // AV1_K_MEANS_DIM == 1
   int64_t this_dist = RENAME(calc_total_dist)(data, centroids, indices, n, k);
 
   for (int i = 0; i < max_itr; ++i) {
@@ -105,7 +109,11 @@
     memcpy(pre_indices, indices, sizeof(pre_indices[0]) * n);
 
     RENAME(calc_centroids)(data, centroids, indices, n, k);
+#if AV1_K_MEANS_DIM == 1  // 1-D indices go through the RTCD-dispatched entry.
+    av1_calc_indices_dim1(data, centroids, indices, n, k);
+#else
     RENAME(av1_calc_indices)(data, centroids, indices, n, k);
+#endif  // AV1_K_MEANS_DIM == 1
     this_dist = RENAME(calc_total_dist)(data, centroids, indices, n, k);
 
     if (this_dist > pre_dist) {
diff --git a/av1/encoder/palette.h b/av1/encoder/palette.h
index 0c77aa1..b1e1b14 100644
--- a/av1/encoder/palette.h
+++ b/av1/encoder/palette.h
@@ -26,11 +26,8 @@
 struct macroblock;
 
 /*!\cond */
-#define AV1_K_MEANS_RENAME(func, dim) func##_dim##dim
+#define AV1_K_MEANS_RENAME(func, dim) func##_dim##dim##_c
 
-void AV1_K_MEANS_RENAME(av1_calc_indices, 1)(const int *data,
-                                             const int *centroids,
-                                             uint8_t *indices, int n, int k);
 void AV1_K_MEANS_RENAME(av1_calc_indices, 2)(const int *data,
                                              const int *centroids,
                                              uint8_t *indices, int n, int k);
@@ -62,9 +59,9 @@
   assert(n > 0);
   assert(k > 0);
   if (dim == 1) {
-    AV1_K_MEANS_RENAME(av1_calc_indices, 1)(data, centroids, indices, n, k);
+    av1_calc_indices_dim1(data, centroids, indices, n, k);
   } else if (dim == 2) {
-    AV1_K_MEANS_RENAME(av1_calc_indices, 2)(data, centroids, indices, n, k);
+    av1_calc_indices_dim2_c(data, centroids, indices, n, k);
   } else {
     assert(0 && "Untemplated k means dimension");
   }
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
new file mode 100644
index 0000000..a96ed2e
--- /dev/null
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2020, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+#include <immintrin.h>  // AVX2
+
+#include "config/aom_dsp_rtcd.h"
+#include "aom_dsp/x86/synonyms.h"
+
+// Assigns every one of the n values in data to its nearest centroid, measured
+// by squared difference, and writes the centroid's index to indices. Equal
+// distances keep the lowest centroid index. Values are processed 8 at a time:
+// if n is not a multiple of 8, the last iteration reads data and writes
+// indices beyond n. k must be at least 1.
+void av1_calc_indices_dim1_avx2(const int *data, const int *centroids,
+                                uint8_t *indices, int n, int k) {
+  const __m256i v_zero = _mm256_setzero_si256();
+
+  for (int i = 0; i < n; i += 8) {
+    const __m256i vals = _mm256_loadu_si256((const __m256i *)(data + i));
+
+    // Centroid 0 seeds the running minimum distance and index; the loop below
+    // replaces them wherever a later centroid is strictly closer.
+    __m256i diff = _mm256_sub_epi32(vals, _mm256_set1_epi32(centroids[0]));
+    __m256i best_dist = _mm256_mullo_epi32(diff, diff);
+    __m256i best_idx = v_zero;
+    for (int j = 1; j < k; j++) {
+      diff = _mm256_sub_epi32(vals, _mm256_set1_epi32(centroids[j]));
+      const __m256i dist = _mm256_mullo_epi32(diff, diff);
+      // Strict '>' keeps the earlier centroid when distances are equal.
+      const __m256i closer = _mm256_cmpgt_epi32(best_dist, dist);
+      best_dist = _mm256_blendv_epi8(best_dist, dist, closer);
+      best_idx = _mm256_blendv_epi8(best_idx, _mm256_set1_epi32(j), closer);
+    }
+
+    // Narrow the eight 32-bit indices to bytes. The packs work per 128-bit
+    // lane, so permute 0x58 moves qword 2 (indices 4..7) next to qword 0.
+    const __m256i idx16 = _mm256_packus_epi32(best_idx, v_zero);
+    const __m256i idx16_lo = _mm256_permute4x64_epi64(idx16, 0x58);
+    const __m256i idx8 = _mm256_packus_epi16(idx16_lo, v_zero);
+    // Only the low 8 bytes carry indices, so store exactly those.
+    _mm_storel_epi64((__m128i *)(indices + i), _mm256_castsi256_si128(idx8));
+  }
+}
diff --git a/test/av1_k_means_test.cc b/test/av1_k_means_test.cc
new file mode 100644
index 0000000..86c72c1
--- /dev/null
+++ b/test/av1_k_means_test.cc
@@ -0,0 +1,158 @@
+/*
+ * Copyright (c) 2020, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <cstdlib>
+#include <new>
+#include <tuple>
+
+#include "config/aom_config.h"
+#include "config/av1_rtcd.h"
+
+#include "aom/aom_codec.h"
+#include "aom/aom_integer.h"
+#include "aom_mem/aom_mem.h"
+#include "aom_ports/aom_timer.h"
+#include "aom_ports/mem.h"
+#include "test/acm_random.h"
+#include "av1/encoder/palette.h"
+#include "test/clear_system_state.h"
+#include "test/register_state_check.h"
+#include "test/util.h"
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+
+namespace AV1Kmeans {
+using av1_calc_indices_dim1_func = void (*)(const int *data,
+                                            const int *centroids,
+                                            uint8_t *indices, int n, int k);
+// Block sizes the tests run on; each gives n = width * height input values.
+const BLOCK_SIZE kValidBlockSize[] = { BLOCK_8X8,   BLOCK_8X16,  BLOCK_8X32,
+                                       BLOCK_16X8,  BLOCK_16X16, BLOCK_16X32,
+                                       BLOCK_32X8,  BLOCK_32X16, BLOCK_32X32,
+                                       BLOCK_32X64, BLOCK_64X32, BLOCK_64X64,
+                                       BLOCK_16X64, BLOCK_64X16 };
+// Test parameter: (implementation under test, block size).
+using av1_calc_indices_dim1Param =
+    std::tuple<av1_calc_indices_dim1_func, BLOCK_SIZE>;
+// Fixture that checks and times a dim-1 index function against the C version.
+class AV1KmeansTest
+    : public ::testing::TestWithParam<av1_calc_indices_dim1Param> {
+ public:
+  ~AV1KmeansTest();
+  void SetUp();
+
+  void TearDown();
+
+ protected:
+  void RunCheckOutput(av1_calc_indices_dim1_func test_impl, BLOCK_SIZE bsize,
+                      int k);
+  void RunSpeedTest(av1_calc_indices_dim1_func test_impl, BLOCK_SIZE bsize,
+                    int k);
+  // True if the first n indices agree; otherwise prints the first mismatch.
+  bool CheckResult(int n) {
+    for (int pos = 0; pos < n; ++pos) {
+      if (indices1_[pos] == indices2_[pos]) continue;
+      printf("%d ", pos);
+      printf("%d != %d ", indices1_[pos], indices2_[pos]);
+      return false;
+    }
+    return true;
+  }
+
+  libaom_test::ACMRandom rnd_;
+  int data_[5096];          // Input values, filled by SetUp().
+  int centroids_[8];        // Centroids from SetUp(); k of them are used.
+  uint8_t indices1_[5096];  // Written by the C version in RunCheckOutput().
+  uint8_t indices2_[5096];  // Written by the implementation under test.
+};
+GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AV1KmeansTest);
+// Nothing to clean up: the fixture has no dynamically allocated members.
+AV1KmeansTest::~AV1KmeansTest() {}
+// Reseeds rnd_ and refills data_ and centroids_ before each test.
+void AV1KmeansTest::SetUp() {
+  rnd_.Reset(libaom_test::ACMRandom::DeterministicSeed());
+  // Every sample and centroid is a Rand8() result shifted left by 4 bits.
+  // indices1_ and indices2_ are not initialized here; the functions under
+  // test write them.
+  for (int &sample : data_) {
+    sample = static_cast<int>(rnd_.Rand8()) << 4;
+  }
+  for (int &centroid : centroids_) {
+    centroid = static_cast<int>(rnd_.Rand8()) << 4;
+  }
+}
+// Calls libaom_test::ClearSystemState() after each test.
+void AV1KmeansTest::TearDown() { libaom_test::ClearSystemState(); }
+// The C version and test_impl must produce identical indices for the block.
+void AV1KmeansTest::RunCheckOutput(av1_calc_indices_dim1_func test_impl,
+                                   BLOCK_SIZE bsize, int k) {
+  const int w = block_size_wide[bsize];
+  const int h = block_size_high[bsize];
+  const int n = w * h;
+  av1_calc_indices_dim1_c(data_, centroids_, indices1_, n, k);
+  test_impl(data_, centroids_, indices2_, n, k);
+
+  ASSERT_EQ(CheckResult(n), true) << " block " << bsize << " Centroids " << k;
+}
+// Times the C version and test_impl on one block and prints ns per call.
+void AV1KmeansTest::RunSpeedTest(av1_calc_indices_dim1_func test_impl,
+                                 BLOCK_SIZE bsize, int k) {
+  const int n = block_size_wide[bsize] * block_size_high[bsize];
+  const int num_loops = 1000000000 / n;
+  const av1_calc_indices_dim1_func impls[2] = { av1_calc_indices_dim1_c,
+                                                test_impl };
+  double ns_per_call[2] = { 0 };
+
+  for (int which = 0; which < 2; ++which) {
+    aom_usec_timer timer;
+    aom_usec_timer_start(&timer);
+    const av1_calc_indices_dim1_func impl = impls[which];
+    for (int iter = 0; iter < num_loops; ++iter) {
+      impl(data_, centroids_, indices1_, n, k);
+    }
+    aom_usec_timer_mark(&timer);
+    const double usec = static_cast<double>(aom_usec_timer_elapsed(&timer));
+    // Microseconds for all iterations -> nanoseconds for a single call.
+    ns_per_call[which] = 1000.0 * usec / num_loops;
+  }
+  printf("av1_calc_indices_dim1 indices= %d centroids=%d: %7.2f/%7.2fns", n, k,
+         ns_per_call[0], ns_per_call[1]);
+  printf("(%3.2f)\n", ns_per_call[0] / ns_per_call[1]);  // Speed-up over C.
+}
+
+TEST_P(AV1KmeansTest, CheckOutput) {
+  // Compare the implementation under test with the C version for every
+  // centroid count from 2 to 8.
+  const av1_calc_indices_dim1_func test_impl = GET_PARAM(0);
+  const BLOCK_SIZE bsize = GET_PARAM(1);
+
+  for (int k = 2; k <= 8; ++k) {
+    RunCheckOutput(test_impl, bsize, k);
+  }
+}
+
+TEST_P(AV1KmeansTest, DISABLED_Speed) {
+  // Benchmark for k = 2..8; DISABLED_ keeps it out of default test runs.
+  const av1_calc_indices_dim1_func test_impl = GET_PARAM(0);
+  const BLOCK_SIZE bsize = GET_PARAM(1);
+
+  for (int k = 2; k <= 8; ++k) {
+    RunSpeedTest(test_impl, bsize, k);
+  }
+}
+// Run the tests with the AVX2 implementation on each size in kValidBlockSize.
+#if HAVE_AVX2
+INSTANTIATE_TEST_SUITE_P(
+    AVX2, AV1KmeansTest,
+    ::testing::Combine(::testing::Values(&av1_calc_indices_dim1_avx2),
+                       ::testing::ValuesIn(kValidBlockSize)));
+#endif  // HAVE_AVX2
+
+}  // namespace AV1Kmeans
diff --git a/test/test.cmake b/test/test.cmake
index bbb6f2d..22e5ce8 100644
--- a/test/test.cmake
+++ b/test/test.cmake
@@ -235,7 +235,8 @@
               "${AOM_ROOT}/test/warp_filter_test.cc"
               "${AOM_ROOT}/test/warp_filter_test_util.cc"
               "${AOM_ROOT}/test/warp_filter_test_util.h"
-              "${AOM_ROOT}/test/webmenc_test.cc")
+              "${AOM_ROOT}/test/webmenc_test.cc"
+              "${AOM_ROOT}/test/av1_k_means_test.cc")
 
   if(CONFIG_REALTIME_ONLY)
     list(REMOVE_ITEM AOM_UNIT_TEST_ENCODER_SOURCES