K-means: switch pipeline to 16 bit.

Less RAM, less for loops in SIMD, simpler SIMD instructions.

On rtc_screen, speed 0, 50 runs, speed-ups are in %:

screen_recording_crd.1920_1080.y4m	0.483
screenshare_buganizer.1900_1306.y4m	0.718
screenshare_colorslides.1820_1320.y4m	0.913
screenshare_slidechanges.1850_1110.y4m	0.576
screenshare_youtube.1680_1178.y4m	0.465
slides_webplot.1920_1080.y4m	0.475
sc_web_browsing720p.y4m	0.690
screen_crd_colwinscroll.1920_1128.y4m	0.146
{OVERALL}	0.558

At speed 10:

screen_recording_crd.1920_1080.y4m	0.347
screenshare_buganizer.1900_1306.y4m	1.145
screenshare_colorslides.1820_1320.y4m	1.028
screenshare_slidechanges.1850_1110.y4m	0.945
screenshare_youtube.1680_1178.y4m	0.366
slides_webplot.1920_1080.y4m	0.999
sc_web_browsing720p.y4m	1.261
screen_crd_colwinscroll.1920_1128.y4m	0.133
{OVERALL}	0.778

Change-Id: I009d31e749c54474b56ff36aec520f686d728f9e
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index b5baaae..ba1dcbb 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -412,10 +412,10 @@
 
   add_proto qw/void av1_quantize_b/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan, const qm_val_t * qm_ptr, const qm_val_t * iqm_ptr, int log_scale";
 
-  add_proto qw/void av1_calc_indices_dim1/, "const int *data, const int *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
+  add_proto qw/void av1_calc_indices_dim1/, "const int16_t *data, const int16_t *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
   specialize qw/av1_calc_indices_dim1 sse2 avx2/;
 
-  add_proto qw/void av1_calc_indices_dim2/, "const int *data, const int *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
+  add_proto qw/void av1_calc_indices_dim2/, "const int16_t *data, const int16_t *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
   specialize qw/av1_calc_indices_dim2 sse2 avx2/;
 
   # ENCODEMB INVOKE
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 3dec881..4185798 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -326,7 +326,7 @@
   //! The best color map found.
   uint8_t best_palette_color_map[MAX_PALETTE_SQUARE];
   //! A temporary buffer used for k-means clustering.
-  int kmeans_data_buf[2 * MAX_PALETTE_SQUARE];
+  int16_t kmeans_data_buf[2 * MAX_PALETTE_SQUARE];
 } PALETTE_BUFFER;
 
 /*! \brief Contains buffers used by av1_compound_type_rd()
diff --git a/av1/encoder/k_means_template.h b/av1/encoder/k_means_template.h
index 0cf72b7..31ffdcf 100644
--- a/av1/encoder/k_means_template.h
+++ b/av1/encoder/k_means_template.h
@@ -28,7 +28,7 @@
 // Though we want to compute the smallest L2 norm, in 1 dimension,
 // it is equivalent to find the smallest L1 norm and then square it.
 // This is preferrable for speed, especially on the SIMD side.
-static int RENAME(calc_dist)(const int *p1, const int *p2) {
+static int RENAME(calc_dist)(const int16_t *p1, const int16_t *p2) {
 #if AV1_K_MEANS_DIM == 1
   return abs(p1[0] - p2[0]);
 #else
@@ -41,7 +41,7 @@
 #endif
 }
 
-void RENAME(av1_calc_indices)(const int *data, const int *centroids,
+void RENAME(av1_calc_indices)(const int16_t *data, const int16_t *centroids,
                               uint8_t *indices, int64_t *dist, int n, int k) {
   if (dist) {
     *dist = 0;
@@ -67,20 +67,22 @@
   }
 }
 
-static void RENAME(calc_centroids)(const int *data, int *centroids,
+static void RENAME(calc_centroids)(const int16_t *data, int16_t *centroids,
                                    const uint8_t *indices, int n, int k) {
   int i, j;
   int count[PALETTE_MAX_SIZE] = { 0 };
+  int centroids_sum[AV1_K_MEANS_DIM * PALETTE_MAX_SIZE];
   unsigned int rand_state = (unsigned int)data[0];
   assert(n <= 32768);
-  memset(centroids, 0, sizeof(centroids[0]) * k * AV1_K_MEANS_DIM);
+  memset(centroids_sum, 0, sizeof(centroids_sum[0]) * k * AV1_K_MEANS_DIM);
 
   for (i = 0; i < n; ++i) {
     const int index = indices[i];
     assert(index < k);
     ++count[index];
     for (j = 0; j < AV1_K_MEANS_DIM; ++j) {
-      centroids[index * AV1_K_MEANS_DIM + j] += data[i * AV1_K_MEANS_DIM + j];
+      centroids_sum[index * AV1_K_MEANS_DIM + j] +=
+          data[i * AV1_K_MEANS_DIM + j];
     }
   }
 
@@ -92,17 +94,17 @@
     } else {
       for (j = 0; j < AV1_K_MEANS_DIM; ++j) {
         centroids[i * AV1_K_MEANS_DIM + j] =
-            DIVIDE_AND_ROUND(centroids[i * AV1_K_MEANS_DIM + j], count[i]);
+            DIVIDE_AND_ROUND(centroids_sum[i * AV1_K_MEANS_DIM + j], count[i]);
       }
     }
   }
 }
 
-void RENAME(av1_k_means)(const int *data, int *centroids, uint8_t *indices,
-                         int n, int k, int max_itr) {
-  int centroids_tmp[AV1_K_MEANS_DIM * PALETTE_MAX_SIZE];
+void RENAME(av1_k_means)(const int16_t *data, int16_t *centroids,
+                         uint8_t *indices, int n, int k, int max_itr) {
+  int16_t centroids_tmp[AV1_K_MEANS_DIM * PALETTE_MAX_SIZE];
   uint8_t indices_tmp[MAX_PALETTE_BLOCK_WIDTH * MAX_PALETTE_BLOCK_HEIGHT];
-  int *meta_centroids[2] = { centroids, centroids_tmp };
+  int16_t *meta_centroids[2] = { centroids, centroids_tmp };
   uint8_t *meta_indices[2] = { indices, indices_tmp };
   int i, l = 0, prev_l, best_l = 0;
   int64_t this_dist;
diff --git a/av1/encoder/palette.c b/av1/encoder/palette.c
index 4375175..9c3d407 100644
--- a/av1/encoder/palette.c
+++ b/av1/encoder/palette.c
@@ -31,14 +31,14 @@
 #include "av1/encoder/k_means_template.h"
 #undef AV1_K_MEANS_DIM
 
-static int int_comparer(const void *a, const void *b) {
-  return (*(int *)a - *(int *)b);
+static int int16_comparer(const void *a, const void *b) {
+  return (*(int16_t *)a - *(int16_t *)b);
 }
 
-int av1_remove_duplicates(int *centroids, int num_centroids) {
+int av1_remove_duplicates(int16_t *centroids, int num_centroids) {
   int num_unique;  // number of unique centroids
   int i;
-  qsort(centroids, num_centroids, sizeof(*centroids), int_comparer);
+  qsort(centroids, num_centroids, sizeof(*centroids), int16_comparer);
   // Remove duplicates.
   num_unique = 1;
   for (i = 1; i < num_centroids; ++i) {
@@ -189,14 +189,14 @@
 // TODO(huisu): Try other schemes to improve compression.
 static AOM_INLINE void optimize_palette_colors(uint16_t *color_cache,
                                                int n_cache, int n_colors,
-                                               int stride, int *centroids,
+                                               int stride, int16_t *centroids,
                                                int bit_depth) {
   if (n_cache <= 0) return;
   for (int i = 0; i < n_colors * stride; i += stride) {
-    int min_diff = abs(centroids[i] - (int)color_cache[0]);
+    int min_diff = abs((int)centroids[i] - (int)color_cache[0]);
     int idx = 0;
     for (int j = 1; j < n_cache; ++j) {
-      const int this_diff = abs(centroids[i] - color_cache[j]);
+      const int this_diff = abs((int)centroids[i] - (int)color_cache[j]);
       if (this_diff < min_diff) {
         min_diff = this_diff;
         idx = j;
@@ -216,8 +216,8 @@
  */
 static AOM_INLINE void palette_rd_y(
     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
-    BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int *centroids, int n,
-    uint16_t *color_cache, int n_cache, bool do_header_rd_based_gating,
+    BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data, int16_t *centroids,
+    int n, uint16_t *color_cache, int n_cache, bool do_header_rd_based_gating,
     MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
     int *rate, int *rate_tokenonly, int64_t *distortion, uint8_t *skippable,
     int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *blk_skip,
@@ -324,14 +324,14 @@
 // returns the best number of colors found.
 static AOM_INLINE int perform_top_color_palette_search(
     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
-    BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int *top_colors,
-    int start_n, int end_n, int step_size, bool do_header_rd_based_gating,
-    int *last_n_searched, uint16_t *color_cache, int n_cache,
-    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
-    int *rate, int *rate_tokenonly, int64_t *distortion, uint8_t *skippable,
-    int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip,
-    uint8_t *tx_type_map, int discount_color_cost) {
-  int centroids[PALETTE_MAX_SIZE];
+    BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data,
+    int16_t *top_colors, int start_n, int end_n, int step_size,
+    bool do_header_rd_based_gating, int *last_n_searched, uint16_t *color_cache,
+    int n_cache, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
+    int64_t *best_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
+    uint8_t *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
+    uint8_t *best_blk_skip, uint8_t *tx_type_map, int discount_color_cost) {
+  int16_t centroids[PALETTE_MAX_SIZE];
   int n = start_n;
   int top_color_winner = end_n;
   /* clang-format off */
@@ -371,7 +371,7 @@
 // returns the best number of colors found.
 static AOM_INLINE int perform_k_means_palette_search(
     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
-    BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int lower_bound,
+    BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data, int lower_bound,
     int upper_bound, int start_n, int end_n, int step_size,
     bool do_header_rd_based_gating, int *last_n_searched, uint16_t *color_cache,
     int n_cache, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
@@ -379,7 +379,7 @@
     uint8_t *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
     uint8_t *best_blk_skip, uint8_t *tx_type_map, uint8_t *color_map,
     int data_points, int discount_color_cost) {
-  int centroids[PALETTE_MAX_SIZE];
+  int16_t centroids[PALETTE_MAX_SIZE];
   const int max_itr = 50;
   int n = start_n;
   int top_color_winner = end_n;
@@ -435,16 +435,19 @@
   *step_size = AOMMAX(1, *max_n - *min_n);
 }
 
-static AOM_INLINE void fill_data_and_get_bounds(
-    const uint8_t *src, const int src_stride, const int rows, const int cols,
-    const int is_high_bitdepth, int *data, int *lower_bound, int *upper_bound) {
+static AOM_INLINE void fill_data_and_get_bounds(const uint8_t *src,
+                                                const int src_stride,
+                                                const int rows, const int cols,
+                                                const int is_high_bitdepth,
+                                                int16_t *data, int *lower_bound,
+                                                int *upper_bound) {
   if (is_high_bitdepth) {
     const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
     *lower_bound = *upper_bound = src_ptr[0];
     for (int r = 0; r < rows; ++r) {
       for (int c = 0; c < cols; ++c) {
         const int val = src_ptr[c];
-        data[c] = val;
+        data[c] = (int16_t)val;
         *lower_bound = AOMMIN(*lower_bound, val);
         *upper_bound = AOMMAX(*upper_bound, val);
       }
@@ -459,7 +462,7 @@
   for (int r = 0; r < rows; ++r) {
     for (int c = 0; c < cols; ++c) {
       const int val = src[c];
-      data[c] = val;
+      data[c] = (int16_t)val;
       *lower_bound = AOMMIN(*lower_bound, val);
       *upper_bound = AOMMAX(*upper_bound, val);
     }
@@ -487,7 +490,7 @@
 }
 
 static void find_top_colors(const int *const count_buf, int bit_depth,
-                            int n_colors, int *top_colors) {
+                            int n_colors, int16_t *top_colors) {
   // Top color array, serving as a priority queue if more than n_colors are
   // found.
   struct ColorCount top_color_counts[PALETTE_MAX_SIZE] = { { 0 } };
@@ -562,8 +565,8 @@
 
   uint8_t *const color_map = xd->plane[0].color_index_map;
   if (colors_threshold > 1 && colors_threshold <= 64) {
-    int *const data = x->palette_buffer->kmeans_data_buf;
-    int centroids[PALETTE_MAX_SIZE];
+    int16_t *const data = x->palette_buffer->kmeans_data_buf;
+    int16_t centroids[PALETTE_MAX_SIZE];
     int lower_bound, upper_bound;
     fill_data_and_get_bounds(src, src_stride, rows, cols, is_hbd, data,
                              &lower_bound, &upper_bound);
@@ -575,7 +578,7 @@
     const int n_cache = av1_get_palette_cache(xd, 0, color_cache);
 
     // Find the dominant colors, stored in top_colors[].
-    int top_colors[PALETTE_MAX_SIZE] = { 0 };
+    int16_t top_colors[PALETTE_MAX_SIZE] = { 0 };
     find_top_colors(count_buf, bit_depth, AOMMIN(colors, PALETTE_MAX_SIZE),
                     top_colors);
 
@@ -791,8 +794,8 @@
     const int max_itr = 50;
     int lb_u, ub_u, val_u;
     int lb_v, ub_v, val_v;
-    int *const data = x->palette_buffer->kmeans_data_buf;
-    int centroids[2 * PALETTE_MAX_SIZE];
+    int16_t *const data = x->palette_buffer->kmeans_data_buf;
+    int16_t centroids[2 * PALETTE_MAX_SIZE];
 
     uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src_u);
     uint16_t *src_v16 = CONVERT_TO_SHORTPTR(src_v);
@@ -917,8 +920,8 @@
   int src_stride = x->plane[1].src.stride;
   const uint8_t *const src_u = x->plane[1].src.buf;
   const uint8_t *const src_v = x->plane[2].src.buf;
-  int *const data = x->palette_buffer->kmeans_data_buf;
-  int centroids[2 * PALETTE_MAX_SIZE];
+  int16_t *const data = x->palette_buffer->kmeans_data_buf;
+  int16_t centroids[2 * PALETTE_MAX_SIZE];
   uint8_t *const color_map = xd->plane[1].color_index_map;
   int r, c;
   const uint16_t *const src_u16 = CONVERT_TO_SHORTPTR(src_u);
diff --git a/av1/encoder/palette.h b/av1/encoder/palette.h
index 6f33f44..7da863a 100644
--- a/av1/encoder/palette.h
+++ b/av1/encoder/palette.h
@@ -28,10 +28,10 @@
 /*!\cond */
 #define AV1_K_MEANS_RENAME(func, dim) func##_dim##dim##_c
 
-void AV1_K_MEANS_RENAME(av1_k_means, 1)(const int *data, int *centroids,
+void AV1_K_MEANS_RENAME(av1_k_means, 1)(const int16_t *data, int16_t *centroids,
                                         uint8_t *indices, int n, int k,
                                         int max_itr);
-void AV1_K_MEANS_RENAME(av1_k_means, 2)(const int *data, int *centroids,
+void AV1_K_MEANS_RENAME(av1_k_means, 2)(const int16_t *data, int16_t *centroids,
                                         uint8_t *indices, int n, int k,
                                         int max_itr);
 /*!\endcond */
@@ -51,8 +51,9 @@
  *
  * \remark Returns nothing, but saves each data's cluster index in \a indices.
  */
-static INLINE void av1_calc_indices(const int *data, const int *centroids,
-                                    uint8_t *indices, int n, int k, int dim) {
+static INLINE void av1_calc_indices(const int16_t *data,
+                                    const int16_t *centroids, uint8_t *indices,
+                                    int n, int k, int dim) {
   assert(n > 0);
   assert(k > 0);
   if (dim == 1) {
@@ -84,7 +85,7 @@
  *
  * \attention The output centroids are rounded off to nearest integers.
  */
-static INLINE void av1_k_means(const int *data, int *centroids,
+static INLINE void av1_k_means(const int16_t *data, int16_t *centroids,
                                uint8_t *indices, int n, int k, int dim,
                                int max_itr) {
   assert(n > 0);
@@ -110,7 +111,7 @@
  * \attention The centroids should be rounded to integers before calling this
  * method.
  */
-int av1_remove_duplicates(int *centroids, int num_centroids);
+int av1_remove_duplicates(int16_t *centroids, int num_centroids);
 
 /*!\brief Checks what colors are in the color cache.
  *
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
index 854b90c..a2db222 100644
--- a/av1/encoder/x86/av1_k_means_avx2.c
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -23,7 +23,55 @@
   return res;
 }
 
-void av1_calc_indices_dim1_avx2(const int *data, const int *centroids,
+void av1_calc_indices_dim1_avx2(const int16_t *data, const int16_t *centroids,
+                                uint8_t *indices, int64_t *total_dist, int n,
+                                int k) {
+  __m256i dist[PALETTE_MAX_SIZE];
+  const __m256i v_zero = _mm256_setzero_si256();
+  __m256i sum = _mm256_setzero_si256();
+
+  for (int i = 0; i < n; i += 16) {
+    __m256i ind = _mm256_loadu_si256((__m256i *)data);
+    for (int j = 0; j < k; j++) {
+      __m256i cent = _mm256_set1_epi16(centroids[j]);
+      __m256i d1 = _mm256_sub_epi16(ind, cent);
+      dist[j] = _mm256_abs_epi16(d1);
+    }
+
+    ind = _mm256_setzero_si256();
+    for (int j = 1; j < k; j++) {
+      __m256i cmp = _mm256_cmpgt_epi16(dist[0], dist[j]);
+      dist[0] = _mm256_min_epi16(dist[0], dist[j]);
+      __m256i ind1 = _mm256_set1_epi16(j);
+      ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind),
+                            _mm256_and_si256(cmp, ind1));
+    }
+
+    __m256i p1 = _mm256_packus_epi16(ind, v_zero);
+    __m256i px = _mm256_permute4x64_epi64(p1, 0x58);
+    __m128i d1 = _mm256_extracti128_si256(px, 0);
+
+    _mm_storeu_si128((__m128i *)indices, d1);
+
+    if (total_dist) {
+      // Square, convert to 32 bit and add together.
+      dist[0] = _mm256_madd_epi16(dist[0], dist[0]);
+      // Convert to 64 bit and add to sum.
+      const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
+      const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
+      sum = _mm256_add_epi64(sum, dist1);
+      sum = _mm256_add_epi64(sum, dist2);
+    }
+
+    indices += 16;
+    data += 16;
+  }
+  if (total_dist) {
+    *total_dist = k_means_horizontal_sum_avx2(sum);
+  }
+}
+
+void av1_calc_indices_dim2_avx2(const int16_t *data, const int16_t *centroids,
                                 uint8_t *indices, int64_t *total_dist, int n,
                                 int k) {
   __m256i dist[PALETTE_MAX_SIZE];
@@ -33,77 +81,18 @@
   for (int i = 0; i < n; i += 8) {
     __m256i ind = _mm256_loadu_si256((__m256i *)data);
     for (int j = 0; j < k; j++) {
-      __m256i cent = _mm256_set1_epi32(centroids[j]);
-      __m256i d1 = _mm256_sub_epi32(ind, cent);
-      dist[j] = _mm256_abs_epi32(d1);
+      const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
+      const __m256i cent = _mm256_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx, cy,
+                                            cx, cy, cx, cy, cx, cy, cx);
+      const __m256i d1 = _mm256_sub_epi16(ind, cent);
+      dist[j] = _mm256_madd_epi16(d1, d1);
     }
 
     ind = _mm256_setzero_si256();
     for (int j = 1; j < k; j++) {
       __m256i cmp = _mm256_cmpgt_epi32(dist[0], dist[j]);
-      __m256i dist1 = _mm256_andnot_si256(cmp, dist[0]);
-      __m256i dist2 = _mm256_and_si256(cmp, dist[j]);
-      dist[0] = _mm256_or_si256(dist1, dist2);
-      __m256i ind1 = _mm256_set1_epi32(j);
-      ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind),
-                            _mm256_and_si256(cmp, ind1));
-    }
-
-    __m256i p1 = _mm256_packus_epi32(ind, v_zero);
-    __m256i px = _mm256_permute4x64_epi64(p1, 0x58);
-    __m256i p2 = _mm256_packus_epi16(px, v_zero);
-    __m128i d1 = _mm256_extracti128_si256(p2, 0);
-
-    _mm_storel_epi64((__m128i *)indices, d1);
-
-    if (total_dist) {
-      // Square, convert to 64 bit and add to sum.
-      dist[0] = _mm256_mullo_epi32(dist[0], dist[0]);
-      const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
-      const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
-      sum = _mm256_add_epi64(sum, dist1);
-      sum = _mm256_add_epi64(sum, dist2);
-    }
-
-    indices += 8;
-    data += 8;
-  }
-  if (total_dist) {
-    *total_dist = k_means_horizontal_sum_avx2(sum);
-  }
-}
-
-void av1_calc_indices_dim2_avx2(const int *data, const int *centroids,
-                                uint8_t *indices, int64_t *total_dist, int n,
-                                int k) {
-  __m256i dist[PALETTE_MAX_SIZE];
-  const __m256i v_zero = _mm256_setzero_si256();
-  const __m256i v_permute = _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7);
-  __m256i sum = _mm256_setzero_si256();
-
-  for (int i = 0; i < n; i += 8) {
-    __m256i ind1 = _mm256_loadu_si256((__m256i *)data);
-    __m256i ind2 = _mm256_loadu_si256((__m256i *)(data + 8));
-    for (int j = 0; j < k; j++) {
-      __m128i cent0 = _mm_loadl_epi64((__m128i const *)&centroids[2 * j]);
-      __m256i cent1 = _mm256_inserti128_si256(v_zero, cent0, 0);
-      cent1 = _mm256_inserti128_si256(cent1, cent0, 1);
-      __m256i cent = _mm256_unpacklo_epi64(cent1, cent1);
-      __m256i d1 = _mm256_sub_epi32(ind1, cent);
-      __m256i d2 = _mm256_sub_epi32(ind2, cent);
-      __m256i d3 = _mm256_mullo_epi32(d1, d1);
-      __m256i d4 = _mm256_mullo_epi32(d2, d2);
-      __m256i d5 = _mm256_hadd_epi32(d3, d4);
-      dist[j] = _mm256_permutevar8x32_epi32(d5, v_permute);
-    }
-
-    __m256i ind = _mm256_setzero_si256();
-    for (int j = 1; j < k; j++) {
-      __m256i cmp = _mm256_cmpgt_epi32(dist[0], dist[j]);
-      __m256i dist1 = _mm256_andnot_si256(cmp, dist[0]);
-      __m256i dist2 = _mm256_and_si256(cmp, dist[j]);
-      dist[0] = _mm256_or_si256(dist1, dist2);
-      ind1 = _mm256_set1_epi32(j);
+      dist[0] = _mm256_min_epi32(dist[0], dist[j]);
+      const __m256i ind1 = _mm256_set1_epi32(j);
       ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind),
                             _mm256_and_si256(cmp, ind1));
     }
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c
index d2c7796..a284fa9 100644
--- a/av1/encoder/x86/av1_k_means_sse2.c
+++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -22,68 +22,49 @@
   return res;
 }
 
-static __m128i absolute_diff_epi32(__m128i a, __m128i b) {
-  const __m128i diff1 = _mm_sub_epi32(a, b);
-  const __m128i diff2 = _mm_sub_epi32(b, a);
-  const __m128i cmp = _mm_cmpgt_epi32(diff1, diff2);
-  const __m128i masked1 = _mm_and_si128(cmp, diff1);
-  const __m128i masked2 = _mm_andnot_si128(cmp, diff2);
-  return _mm_or_si128(masked1, masked2);
-}
-
-void av1_calc_indices_dim1_sse2(const int *data, const int *centroids,
+void av1_calc_indices_dim1_sse2(const int16_t *data, const int16_t *centroids,
                                 uint8_t *indices, int64_t *total_dist, int n,
                                 int k) {
   const __m128i v_zero = _mm_setzero_si128();
-  int l = 1;
   __m128i dist[PALETTE_MAX_SIZE];
-  __m128i ind[2];
   __m128i sum = _mm_setzero_si128();
 
-  for (int i = 0; i < n; i += 4) {
-    l = (l == 0) ? 1 : 0;
-    ind[l] = _mm_loadu_si128((__m128i *)data);
+  for (int i = 0; i < n; i += 8) {
+    __m128i in = _mm_loadu_si128((__m128i *)data);
     for (int j = 0; j < k; j++) {
-      __m128i cent = _mm_set1_epi32(centroids[j]);
-      dist[j] = absolute_diff_epi32(ind[l], cent);
+      __m128i cent = _mm_set1_epi16(centroids[j]);
+      __m128i d1 = _mm_sub_epi16(in, cent);
+      __m128i d2 = _mm_sub_epi16(cent, in);
+      dist[j] = _mm_max_epi16(d1, d2);
     }
 
-    ind[l] = _mm_setzero_si128();
+    __m128i ind = _mm_setzero_si128();
     for (int j = 1; j < k; j++) {
-      __m128i cmp = _mm_cmpgt_epi32(dist[0], dist[j]);
-      __m128i dist1 = _mm_andnot_si128(cmp, dist[0]);
-      __m128i dist2 = _mm_and_si128(cmp, dist[j]);
-      dist[0] = _mm_or_si128(dist1, dist2);
-      __m128i ind1 = _mm_set1_epi32(j);
-      ind[l] =
-          _mm_or_si128(_mm_andnot_si128(cmp, ind[l]), _mm_and_si128(cmp, ind1));
+      __m128i cmp = _mm_cmpgt_epi16(dist[0], dist[j]);
+      dist[0] = _mm_min_epi16(dist[0], dist[j]);
+      __m128i ind1 = _mm_set1_epi16(j);
+      ind = _mm_or_si128(_mm_andnot_si128(cmp, ind), _mm_and_si128(cmp, ind1));
     }
-    ind[l] = _mm_packus_epi16(ind[l], v_zero);
     if (total_dist) {
-      // Square and convert to 32 bit.
-      const __m128i d1 = _mm_packs_epi32(dist[0], v_zero);
-      const __m128i d2 = _mm_mullo_epi16(d1, d1);
-      const __m128i d3 = _mm_mulhi_epi16(d1, d1);
-      dist[0] = _mm_unpacklo_epi16(d2, d3);
+      // Square, convert to 32 bit and add together.
+      dist[0] = _mm_madd_epi16(dist[0], dist[0]);
       // Convert to 64 bit and add to sum.
       const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
       const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
       sum = _mm_add_epi64(sum, dist1);
       sum = _mm_add_epi64(sum, dist2);
     }
-    if (l == 1) {
-      __m128i p2 = _mm_packus_epi16(_mm_unpacklo_epi64(ind[0], ind[1]), v_zero);
-      _mm_storel_epi64((__m128i *)indices, p2);
-      indices += 8;
-    }
-    data += 4;
+    __m128i p2 = _mm_packus_epi16(ind, v_zero);
+    _mm_storel_epi64((__m128i *)indices, p2);
+    indices += 8;
+    data += 8;
   }
   if (total_dist) {
     *total_dist = k_means_horizontal_sum_sse2(sum);
   }
 }
 
-void av1_calc_indices_dim2_sse2(const int *data, const int *centroids,
+void av1_calc_indices_dim2_sse2(const int16_t *data, const int16_t *centroids,
                                 uint8_t *indices, int64_t *total_dist, int n,
                                 int k) {
   const __m128i v_zero = _mm_setzero_si128();
@@ -95,19 +76,11 @@
   for (int i = 0; i < n; i += 4) {
     l = (l == 0) ? 1 : 0;
     __m128i ind1 = _mm_loadu_si128((__m128i *)data);
-    __m128i ind2 = _mm_loadu_si128((__m128i *)(data + 4));
-    __m128i indl = _mm_unpacklo_epi32(ind1, ind2);
-    __m128i indh = _mm_unpackhi_epi32(ind1, ind2);
-    ind1 = _mm_unpacklo_epi32(indl, indh);
-    ind2 = _mm_unpackhi_epi32(indl, indh);
     for (int j = 0; j < k; j++) {
-      __m128i cent0 = _mm_set1_epi32(centroids[2 * j]);
-      __m128i cent1 = _mm_set1_epi32(centroids[2 * j + 1]);
-      __m128i d1 = absolute_diff_epi32(ind1, cent0);
-      __m128i d2 = absolute_diff_epi32(ind2, cent1);
-      __m128i d3 = _mm_madd_epi16(d1, d1);
-      __m128i d4 = _mm_madd_epi16(d2, d2);
-      dist[j] = _mm_add_epi32(d3, d4);
+      const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
+      const __m128i cent = _mm_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx);
+      const __m128i d1 = _mm_sub_epi16(ind1, cent);
+      dist[j] = _mm_madd_epi16(d1, d1);
     }
 
     ind[l] = _mm_setzero_si128();
diff --git a/test/av1_k_means_test.cc b/test/av1_k_means_test.cc
index 5b6c22e..221dd10 100644
--- a/test/av1_k_means_test.cc
+++ b/test/av1_k_means_test.cc
@@ -28,12 +28,12 @@
 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
 
 namespace AV1Kmeans {
-typedef void (*av1_calc_indices_dim1_func)(const int *data,
-                                           const int *centroids,
+typedef void (*av1_calc_indices_dim1_func)(const int16_t *data,
+                                           const int16_t *centroids,
                                            uint8_t *indices,
                                            int64_t *total_dist, int n, int k);
-typedef void (*av1_calc_indices_dim2_func)(const int *data,
-                                           const int *centroids,
+typedef void (*av1_calc_indices_dim2_func)(const int16_t *data,
+                                           const int16_t *centroids,
                                            uint8_t *indices,
                                            int64_t *total_dist, int n, int k);
 
@@ -68,8 +68,8 @@
   }
 
   libaom_test::ACMRandom rnd_;
-  int data_[4096];
-  int centroids_[8];
+  int16_t data_[4096];
+  int16_t centroids_[8];
   uint8_t indices1_[4096];
   uint8_t indices2_[4096];
 };
@@ -178,8 +178,8 @@
   }
 
   libaom_test::ACMRandom rnd_;
-  int data_[4096 * 2];
-  int centroids_[8 * 2];
+  int16_t data_[4096 * 2];
+  int16_t centroids_[8 * 2];
   uint8_t indices1_[4096];
   uint8_t indices2_[4096];
 };