K-means: misc refactoring and speed-ups.

Suggestions from previous reviews are implemented (see the scalar sketch after
this list):
- store the centroids once and for all
- do not store all distances: keep only the current one and the running minimum
- store twice as many indices per store in the 2D AVX2 version
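
As an illustration of the second point, here is a minimal scalar sketch of the
dim-1 logic: only the current distance and the running minimum are kept per
sample, and the squared minimum is accumulated as in the SIMD code. The helper
name is hypothetical and is not part of this change.

#include <stdint.h>
#include <stdlib.h>

// Scalar reference for av1_calc_indices_dim1_*: keep only the running
// minimum distance and its index, instead of storing all k distances.
static void calc_indices_dim1_scalar(const int16_t *data,
                                     const int16_t *centroids,
                                     uint8_t *indices, int64_t *total_dist,
                                     int n, int k) {
  int64_t sum = 0;
  for (int i = 0; i < n; ++i) {
    int dist_min = abs(data[i] - centroids[0]);
    uint8_t ind = 0;
    for (int j = 1; j < k; ++j) {
      const int dist = abs(data[i] - centroids[j]);
      if (dist < dist_min) {
        dist_min = dist;
        ind = (uint8_t)j;
      }
    }
    indices[i] = ind;
    // Accumulate the squared minimum distance, as the SIMD code does.
    sum += (int64_t)dist_min * dist_min;
  }
  if (total_dist) *total_dist = sum;
}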

On the rtc_screen set, speed 0, 50 runs, encoding speed-ups in %:

File	encoding_spdup
screen_recording_crd.1920_1080.y4m	0.289
screenshare_buganizer.1900_1306.y4m	0.136
screenshare_colorslides.1820_1320.y4m	0.232
screenshare_slidechanges.1850_1110.y4m	0.110
screenshare_youtube.1680_1178.y4m	0.534
slides_webplot.1920_1080.y4m	0.112
sc_web_browsing720p.y4m	0.548
screen_crd_colwinscroll.1920_1128.y4m	0.225
{OVERALL}	0.273

Speed 10, same setup:

File	encoding_spdup
screen_recording_crd.1920_1080.y4m	0.277
screenshare_buganizer.1900_1306.y4m	0.382
screenshare_colorslides.1820_1320.y4m	0.391
screenshare_slidechanges.1850_1110.y4m	0.341
screenshare_youtube.1680_1178.y4m	0.813
slides_webplot.1920_1080.y4m	0.287
sc_web_browsing720p.y4m	0.611
screen_crd_colwinscroll.1920_1128.y4m	0.234
{OVERALL}	0.417

Change-Id: I7178ec5c89c9667aac1818c9eb1eda33ef8f81a9
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
index a2db222..ad0b374 100644
--- a/av1/encoder/x86/av1_k_means_avx2.c
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -26,39 +26,44 @@
 void av1_calc_indices_dim1_avx2(const int16_t *data, const int16_t *centroids,
                                 uint8_t *indices, int64_t *total_dist, int n,
                                 int k) {
-  __m256i dist[PALETTE_MAX_SIZE];
   const __m256i v_zero = _mm256_setzero_si256();
   __m256i sum = _mm256_setzero_si256();
+  __m256i cents[PALETTE_MAX_SIZE];
+  for (int j = 0; j < k; ++j) {
+    cents[j] = _mm256_set1_epi16(centroids[j]);
+  }
 
   for (int i = 0; i < n; i += 16) {
-    __m256i ind = _mm256_loadu_si256((__m256i *)data);
-    for (int j = 0; j < k; j++) {
-      __m256i cent = _mm256_set1_epi16(centroids[j]);
-      __m256i d1 = _mm256_sub_epi16(ind, cent);
-      dist[j] = _mm256_abs_epi16(d1);
-    }
+    const __m256i in = _mm256_loadu_si256((__m256i *)data);
+    __m256i ind = _mm256_setzero_si256();
+    // Compute the distance to the first centroid.
+    __m256i d1 = _mm256_sub_epi16(in, cents[0]);
+    __m256i dist_min = _mm256_abs_epi16(d1);
 
-    ind = _mm256_setzero_si256();
-    for (int j = 1; j < k; j++) {
-      __m256i cmp = _mm256_cmpgt_epi16(dist[0], dist[j]);
-      dist[0] = _mm256_min_epi16(dist[0], dist[j]);
-      __m256i ind1 = _mm256_set1_epi16(j);
+    for (int j = 1; j < k; ++j) {
+      // Compute the distance to the centroid.
+      d1 = _mm256_sub_epi16(in, cents[j]);
+      const __m256i dist = _mm256_abs_epi16(d1);
+      // Compare to the minimal one.
+      const __m256i cmp = _mm256_cmpgt_epi16(dist_min, dist);
+      dist_min = _mm256_min_epi16(dist_min, dist);
+      const __m256i ind1 = _mm256_set1_epi16(j);
       ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind),
                             _mm256_and_si256(cmp, ind1));
     }
 
-    __m256i p1 = _mm256_packus_epi16(ind, v_zero);
-    __m256i px = _mm256_permute4x64_epi64(p1, 0x58);
-    __m128i d1 = _mm256_extracti128_si256(px, 0);
+    const __m256i p1 = _mm256_packus_epi16(ind, v_zero);
+    const __m256i px = _mm256_permute4x64_epi64(p1, 0x58);
+    const __m128i d2 = _mm256_extracti128_si256(px, 0);
 
-    _mm_storeu_si128((__m128i *)indices, d1);
+    _mm_storeu_si128((__m128i *)indices, d2);
 
     if (total_dist) {
       // Square, convert to 32 bit and add together.
-      dist[0] = _mm256_madd_epi16(dist[0], dist[0]);
+      dist_min = _mm256_madd_epi16(dist_min, dist_min);
       // Convert to 64 bit and add to sum.
-      const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
-      const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
+      const __m256i dist1 = _mm256_unpacklo_epi32(dist_min, v_zero);
+      const __m256i dist2 = _mm256_unpackhi_epi32(dist_min, v_zero);
       sum = _mm256_add_epi64(sum, dist1);
       sum = _mm256_add_epi64(sum, dist2);
     }
@@ -74,46 +79,52 @@
 void av1_calc_indices_dim2_avx2(const int16_t *data, const int16_t *centroids,
                                 uint8_t *indices, int64_t *total_dist, int n,
                                 int k) {
-  __m256i dist[PALETTE_MAX_SIZE];
   const __m256i v_zero = _mm256_setzero_si256();
+  const __m256i permute = _mm256_set_epi32(0, 0, 0, 0, 5, 1, 4, 0);
   __m256i sum = _mm256_setzero_si256();
+  __m256i ind[2];
+  __m256i cents[PALETTE_MAX_SIZE];
+  for (int j = 0; j < k; ++j) {
+    const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
+    cents[j] = _mm256_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx, cy, cx, cy, cx,
+                                cy, cx, cy, cx);
+  }
 
-  for (int i = 0; i < n; i += 8) {
-    __m256i ind = _mm256_loadu_si256((__m256i *)data);
-    for (int j = 0; j < k; j++) {
-      const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
-      const __m256i cent = _mm256_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx, cy,
-                                            cx, cy, cx, cy, cx, cy, cx);
-      const __m256i d1 = _mm256_sub_epi16(ind, cent);
-      dist[j] = _mm256_madd_epi16(d1, d1);
+  for (int i = 0; i < n; i += 16) {
+    for (int l = 0; l < 2; ++l) {
+      const __m256i in = _mm256_loadu_si256((__m256i *)data);
+      ind[l] = _mm256_setzero_si256();
+      // Compute the distance to the first centroid.
+      __m256i d1 = _mm256_sub_epi16(in, cents[0]);
+      __m256i dist_min = _mm256_madd_epi16(d1, d1);
+
+      for (int j = 1; j < k; ++j) {
+        // Compute the distance to the centroid.
+        d1 = _mm256_sub_epi16(in, cents[j]);
+        const __m256i dist = _mm256_madd_epi16(d1, d1);
+        // Compare to the minimal one.
+        const __m256i cmp = _mm256_cmpgt_epi32(dist_min, dist);
+        dist_min = _mm256_min_epi32(dist_min, dist);
+        const __m256i ind1 = _mm256_set1_epi32(j);
+        ind[l] = _mm256_or_si256(_mm256_andnot_si256(cmp, ind[l]),
+                                 _mm256_and_si256(cmp, ind1));
+      }
+      if (total_dist) {
+        // Convert to 64 bit and add to sum.
+        const __m256i dist1 = _mm256_unpacklo_epi32(dist_min, v_zero);
+        const __m256i dist2 = _mm256_unpackhi_epi32(dist_min, v_zero);
+        sum = _mm256_add_epi64(sum, dist1);
+        sum = _mm256_add_epi64(sum, dist2);
+      }
+      data += 16;
     }
-
-    ind = _mm256_setzero_si256();
-    for (int j = 1; j < k; j++) {
-      __m256i cmp = _mm256_cmpgt_epi32(dist[0], dist[j]);
-      dist[0] = _mm256_min_epi32(dist[0], dist[j]);
-      const __m256i ind1 = _mm256_set1_epi32(j);
-      ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind),
-                            _mm256_and_si256(cmp, ind1));
-    }
-
-    __m256i p1 = _mm256_packus_epi32(ind, v_zero);
-    __m256i px = _mm256_permute4x64_epi64(p1, 0x58);
-    __m256i p2 = _mm256_packus_epi16(px, v_zero);
-    __m128i d1 = _mm256_extracti128_si256(p2, 0);
-
-    _mm_storel_epi64((__m128i *)indices, d1);
-
-    if (total_dist) {
-      // Convert to 64 bit and add to sum.
-      const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
-      const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
-      sum = _mm256_add_epi64(sum, dist1);
-      sum = _mm256_add_epi64(sum, dist2);
-    }
-
-    indices += 8;
-    data += 16;
+    // Cast to 8 bit and store.
+    const __m256i d2 = _mm256_packus_epi32(ind[0], ind[1]);
+    const __m256i d3 = _mm256_packus_epi16(d2, v_zero);
+    const __m256i d4 = _mm256_permutevar8x32_epi32(d3, permute);
+    const __m128i d5 = _mm256_extracti128_si256(d4, 0);
+    _mm_storeu_si128((__m128i *)indices, d5);
+    indices += 16;
   }
   if (total_dist) {
     *total_dist = k_means_horizontal_sum_avx2(sum);
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c
index a284fa9..4338bf7 100644
--- a/av1/encoder/x86/av1_k_means_sse2.c
+++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -26,31 +26,37 @@
                                 uint8_t *indices, int64_t *total_dist, int n,
                                 int k) {
   const __m128i v_zero = _mm_setzero_si128();
-  __m128i dist[PALETTE_MAX_SIZE];
   __m128i sum = _mm_setzero_si128();
+  __m128i cents[PALETTE_MAX_SIZE];
+  for (int j = 0; j < k; ++j) {
+    cents[j] = _mm_set1_epi16(centroids[j]);
+  }
 
   for (int i = 0; i < n; i += 8) {
-    __m128i in = _mm_loadu_si128((__m128i *)data);
-    for (int j = 0; j < k; j++) {
-      __m128i cent = _mm_set1_epi16(centroids[j]);
-      __m128i d1 = _mm_sub_epi16(in, cent);
-      __m128i d2 = _mm_sub_epi16(cent, in);
-      dist[j] = _mm_max_epi16(d1, d2);
-    }
-
+    const __m128i in = _mm_loadu_si128((__m128i *)data);
     __m128i ind = _mm_setzero_si128();
-    for (int j = 1; j < k; j++) {
-      __m128i cmp = _mm_cmpgt_epi16(dist[0], dist[j]);
-      dist[0] = _mm_min_epi16(dist[0], dist[j]);
-      __m128i ind1 = _mm_set1_epi16(j);
+    // Compute the distance to the first centroid.
+    __m128i d1 = _mm_sub_epi16(in, cents[0]);
+    __m128i d2 = _mm_sub_epi16(cents[0], in);
+    __m128i dist_min = _mm_max_epi16(d1, d2);
+
+    for (int j = 1; j < k; ++j) {
+      // Compute the distance to the centroid.
+      d1 = _mm_sub_epi16(in, cents[j]);
+      d2 = _mm_sub_epi16(cents[j], in);
+      const __m128i dist = _mm_max_epi16(d1, d2);
+      // Compare to the minimal one.
+      const __m128i cmp = _mm_cmpgt_epi16(dist_min, dist);
+      dist_min = _mm_min_epi16(dist_min, dist);
+      const __m128i ind1 = _mm_set1_epi16(j);
       ind = _mm_or_si128(_mm_andnot_si128(cmp, ind), _mm_and_si128(cmp, ind1));
     }
     if (total_dist) {
       // Square, convert to 32 bit and add together.
-      dist[0] = _mm_madd_epi16(dist[0], dist[0]);
+      dist_min = _mm_madd_epi16(dist_min, dist_min);
       // Convert to 64 bit and add to sum.
-      const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
-      const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
+      const __m128i dist1 = _mm_unpacklo_epi32(dist_min, v_zero);
+      const __m128i dist2 = _mm_unpackhi_epi32(dist_min, v_zero);
       sum = _mm_add_epi64(sum, dist1);
       sum = _mm_add_epi64(sum, dist2);
     }
@@ -68,45 +74,49 @@
                                 uint8_t *indices, int64_t *total_dist, int n,
                                 int k) {
   const __m128i v_zero = _mm_setzero_si128();
-  int l = 1;
-  __m128i dist[PALETTE_MAX_SIZE];
-  __m128i ind[2];
   __m128i sum = _mm_setzero_si128();
+  __m128i ind[2];
+  __m128i cents[PALETTE_MAX_SIZE];
+  for (int j = 0; j < k; ++j) {
+    const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
+    cents[j] = _mm_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx);
+  }
 
-  for (int i = 0; i < n; i += 4) {
-    l = (l == 0) ? 1 : 0;
-    __m128i ind1 = _mm_loadu_si128((__m128i *)data);
-    for (int j = 0; j < k; j++) {
-      const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
-      const __m128i cent = _mm_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx);
-      const __m128i d1 = _mm_sub_epi16(ind1, cent);
-      dist[j] = _mm_madd_epi16(d1, d1);
-    }
+  for (int i = 0; i < n; i += 8) {
+    for (int l = 0; l < 2; ++l) {
+      const __m128i in = _mm_loadu_si128((__m128i *)data);
+      ind[l] = _mm_setzero_si128();
+      // Compute the distance to the first centroid.
+      __m128i d1 = _mm_sub_epi16(in, cents[0]);
+      __m128i dist_min = _mm_madd_epi16(d1, d1);
 
-    ind[l] = _mm_setzero_si128();
-    for (int j = 1; j < k; j++) {
-      __m128i cmp = _mm_cmpgt_epi32(dist[0], dist[j]);
-      __m128i dist1 = _mm_andnot_si128(cmp, dist[0]);
-      __m128i dist2 = _mm_and_si128(cmp, dist[j]);
-      dist[0] = _mm_or_si128(dist1, dist2);
-      ind1 = _mm_set1_epi32(j);
-      ind[l] =
-          _mm_or_si128(_mm_andnot_si128(cmp, ind[l]), _mm_and_si128(cmp, ind1));
+      for (int j = 1; j < k; ++j) {
+        // Compute the distance to the centroid.
+        d1 = _mm_sub_epi16(in, cents[j]);
+        const __m128i dist = _mm_madd_epi16(d1, d1);
+        // Compare to the minimal one.
+        const __m128i cmp = _mm_cmpgt_epi32(dist_min, dist);
+        const __m128i dist1 = _mm_andnot_si128(cmp, dist_min);
+        const __m128i dist2 = _mm_and_si128(cmp, dist);
+        dist_min = _mm_or_si128(dist1, dist2);
+        const __m128i ind1 = _mm_set1_epi32(j);
+        ind[l] = _mm_or_si128(_mm_andnot_si128(cmp, ind[l]),
+                              _mm_and_si128(cmp, ind1));
+      }
+      if (total_dist) {
+        // Convert to 64 bit and add to sum.
+        const __m128i dist1 = _mm_unpacklo_epi32(dist_min, v_zero);
+        const __m128i dist2 = _mm_unpackhi_epi32(dist_min, v_zero);
+        sum = _mm_add_epi64(sum, dist1);
+        sum = _mm_add_epi64(sum, dist2);
+      }
+      data += 8;
     }
-    ind[l] = _mm_packus_epi16(ind[l], v_zero);
-    if (total_dist) {
-      // Convert to 64 bit and add to sum.
-      const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
-      const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
-      sum = _mm_add_epi64(sum, dist1);
-      sum = _mm_add_epi64(sum, dist2);
-    }
-    if (l == 1) {
-      __m128i p2 = _mm_packus_epi16(_mm_unpacklo_epi64(ind[0], ind[1]), v_zero);
-      _mm_storel_epi64((__m128i *)indices, p2);
-      indices += 8;
-    }
-    data += 8;
+    // Cast to 8 bit and store.
+    const __m128i d2 = _mm_packus_epi16(ind[0], ind[1]);
+    const __m128i d3 = _mm_packus_epi16(d2, v_zero);
+    _mm_storel_epi64((__m128i *)indices, d3);
+    indices += 8;
   }
   if (total_dist) {
     *total_dist = k_means_horizontal_sum_sse2(sum);