K-means: use L1 norm when dim == 1.

In 1 dimension, finding the nearest neighbor using L2 or L1 is
equivalent. Only the best distance needs to be squared.

The speed-up is negligible: the goal is to make the code easier
to transition to 16-bit inputs.

Change-Id: Ibbff387aec1055655b4d96f3b28d0dcfddd80e6e
diff --git a/av1/encoder/k_means_template.h b/av1/encoder/k_means_template.h
index 9560da5..0cf72b7 100644
--- a/av1/encoder/k_means_template.h
+++ b/av1/encoder/k_means_template.h
@@ -11,6 +11,7 @@
 
 #include <assert.h>
 #include <stdint.h>
+#include <stdlib.h>
 #include <string.h>
 
 #include "av1/common/blockd.h"
@@ -24,13 +25,20 @@
 #define RENAME_(x, y) AV1_K_MEANS_RENAME(x, y)
 #define RENAME(x) RENAME_(x, AV1_K_MEANS_DIM)
 
+// Though we want to compute the smallest L2 norm, in 1 dimension,
+// it is equivalent to finding the smallest L1 norm and then squaring it.
+// This is preferable for speed, especially on the SIMD side.
 static int RENAME(calc_dist)(const int *p1, const int *p2) {
+#if AV1_K_MEANS_DIM == 1
+  return abs(p1[0] - p2[0]);
+#else
   int dist = 0;
   for (int i = 0; i < AV1_K_MEANS_DIM; ++i) {
     const int diff = p1[i] - p2[i];
     dist += diff * diff;
   }
   return dist;
+#endif
 }
 
 void RENAME(av1_calc_indices)(const int *data, const int *centroids,
@@ -50,7 +58,11 @@
       }
     }
     if (dist) {
+#if AV1_K_MEANS_DIM == 1
+      *dist += min_dist * min_dist;
+#else
       *dist += min_dist;
+#endif
     }
   }
 }
@@ -97,7 +109,7 @@
 
   assert(n <= MAX_PALETTE_BLOCK_WIDTH * MAX_PALETTE_BLOCK_HEIGHT);
 
-#if AV1_K_MEANS_DIM - 2
+#if AV1_K_MEANS_DIM == 1
   av1_calc_indices_dim1(data, centroids, indices, &this_dist, n, k);
 #else
   av1_calc_indices_dim2(data, centroids, indices, &this_dist, n, k);
@@ -109,7 +121,7 @@
     l = (l == 1) ? 0 : 1;
 
     RENAME(calc_centroids)(data, meta_centroids[l], meta_indices[prev_l], n, k);
-#if AV1_K_MEANS_DIM - 2
+#if AV1_K_MEANS_DIM == 1
     av1_calc_indices_dim1(data, meta_centroids[l], meta_indices[l], &this_dist,
                           n, k);
 #else
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
index 2745ac1..854b90c 100644
--- a/av1/encoder/x86/av1_k_means_avx2.c
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -35,7 +35,7 @@
     for (int j = 0; j < k; j++) {
       __m256i cent = _mm256_set1_epi32(centroids[j]);
       __m256i d1 = _mm256_sub_epi32(ind, cent);
-      dist[j] = _mm256_mullo_epi32(d1, d1);
+      dist[j] = _mm256_abs_epi32(d1);
     }
 
     ind = _mm256_setzero_si256();
@@ -57,7 +57,8 @@
     _mm_storel_epi64((__m128i *)indices, d1);
 
     if (total_dist) {
-      // Convert to 64 bit and add to sum.
+      // Square, convert to 64 bit and add to sum.
+      dist[0] = _mm256_mullo_epi32(dist[0], dist[0]);
       const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
       const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
       sum = _mm256_add_epi64(sum, dist1);
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c
index 2c12346..d2c7796 100644
--- a/av1/encoder/x86/av1_k_means_sse2.c
+++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -22,6 +22,15 @@
   return res;
 }
 
+static __m128i absolute_diff_epi32(__m128i a, __m128i b) {
+  const __m128i diff1 = _mm_sub_epi32(a, b);
+  const __m128i diff2 = _mm_sub_epi32(b, a);
+  const __m128i cmp = _mm_cmpgt_epi32(diff1, diff2);
+  const __m128i masked1 = _mm_and_si128(cmp, diff1);
+  const __m128i masked2 = _mm_andnot_si128(cmp, diff2);
+  return _mm_or_si128(masked1, masked2);
+}
+
 void av1_calc_indices_dim1_sse2(const int *data, const int *centroids,
                                 uint8_t *indices, int64_t *total_dist, int n,
                                 int k) {
@@ -36,11 +45,7 @@
     ind[l] = _mm_loadu_si128((__m128i *)data);
     for (int j = 0; j < k; j++) {
       __m128i cent = _mm_set1_epi32(centroids[j]);
-      __m128i d1 = _mm_sub_epi32(ind[l], cent);
-      __m128i d2 = _mm_packs_epi32(d1, d1);
-      __m128i d3 = _mm_mullo_epi16(d2, d2);
-      __m128i d4 = _mm_mulhi_epi16(d2, d2);
-      dist[j] = _mm_unpacklo_epi16(d3, d4);
+      dist[j] = absolute_diff_epi32(ind[l], cent);
     }
 
     ind[l] = _mm_setzero_si128();
@@ -55,6 +60,11 @@
     }
     ind[l] = _mm_packus_epi16(ind[l], v_zero);
     if (total_dist) {
+      // Square and convert to 32 bit.
+      const __m128i d1 = _mm_packs_epi32(dist[0], v_zero);
+      const __m128i d2 = _mm_mullo_epi16(d1, d1);
+      const __m128i d3 = _mm_mulhi_epi16(d1, d1);
+      dist[0] = _mm_unpacklo_epi16(d2, d3);
       // Convert to 64 bit and add to sum.
       const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
       const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
@@ -73,15 +83,6 @@
   }
 }
 
-static __m128i absolute_diff_epi32(__m128i a, __m128i b) {
-  const __m128i diff1 = _mm_sub_epi32(a, b);
-  const __m128i diff2 = _mm_sub_epi32(b, a);
-  const __m128i cmp = _mm_cmpgt_epi32(diff1, diff2);
-  const __m128i masked1 = _mm_and_si128(cmp, diff1);
-  const __m128i masked2 = _mm_andnot_si128(cmp, diff2);
-  return _mm_or_si128(masked1, masked2);
-}
-
 void av1_calc_indices_dim2_sse2(const int *data, const int *centroids,
                                 uint8_t *indices, int64_t *total_dist, int n,
                                 int k) {