Fix several bugs in K-means.

Most importantly, the unit tests now check that total_dist is identical
between the C reference and the optimized implementations on all
platforms.

Some were introduced by the recent commit
8cbaea62af09e92a81853e1787c9912b3ac6999c:
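- dist instead of *dist: the C template incremented the pointer instead
of the accumulated distance. This went unnoticed because the C version
was not part of my SSIM check.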
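- the k_means_horizontal_sum_avx2 function has been rewritten to add the
two 128-bit halves in registers and finish the reduction with SSE2
instructions (to avoid the aligned 256-bit store into a stack array that
is not guaranteed to be 32-byte aligned)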
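- the 32-bit to int64_t conversions were wrong: the unpack operands were
swapped, so every distance was shifted left by 32 bits (though the
results were still right, because only a comparison is needed and all
values were shifted by the same amount)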

Some are much older:
- av1_calc_indices_dim2_sse2 was actually wrong, as it was applying
_mm_madd_epi16 to 32-bit lanes, which only works for non-negative
numbers. For a negative difference the upper 16 bits of the lane are
-1, so the squared result is off by 1 (-1 * -1).
The solution for now is to take the absolute value of the difference.
The proper solution is to switch the pipeline to 16 bit, as this is
what the code assumes.

Change-Id: Iaeeefa9f41d5262a622f6b64b9221696a2335c8d
diff --git a/av1/encoder/k_means_template.h b/av1/encoder/k_means_template.h
index fe3e9ef..9560da5 100644
--- a/av1/encoder/k_means_template.h
+++ b/av1/encoder/k_means_template.h
@@ -50,7 +50,7 @@
       }
     }
     if (dist) {
-      dist += min_dist;
+      *dist += min_dist;
     }
   }
 }
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
index 2177358..2745ac1 100644
--- a/av1/encoder/x86/av1_k_means_avx2.c
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -14,9 +14,13 @@
 #include "aom_dsp/x86/synonyms.h"
 
 static int64_t k_means_horizontal_sum_avx2(__m256i a) {
-  int64_t dists[4];
-  _mm256_store_si256((__m256i *)dists, a);
-  return a[0] + a[1] + a[2] + a[3];
+  const __m128i low = _mm256_castsi256_si128(a);
+  const __m128i high = _mm256_extracti128_si256(a, 1);
+  const __m128i sum = _mm_add_epi64(low, high);
+  const __m128i sum_high = _mm_unpackhi_epi64(sum, sum);
+  int64_t res;
+  _mm_storel_epi64((__m128i *)&res, _mm_add_epi64(sum, sum_high));
+  return res;
 }
 
 void av1_calc_indices_dim1_avx2(const int *data, const int *centroids,
@@ -54,8 +58,8 @@
 
     if (total_dist) {
       // Convert to 64 bit and add to sum.
-      const __m256i dist1 = _mm256_unpacklo_epi32(v_zero, dist[0]);
-      const __m256i dist2 = _mm256_unpackhi_epi32(v_zero, dist[0]);
+      const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
+      const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
       sum = _mm256_add_epi64(sum, dist1);
       sum = _mm256_add_epi64(sum, dist2);
     }
@@ -112,8 +116,8 @@
 
     if (total_dist) {
       // Convert to 64 bit and add to sum.
-      const __m256i dist1 = _mm256_unpacklo_epi32(v_zero, dist[0]);
-      const __m256i dist2 = _mm256_unpackhi_epi32(v_zero, dist[0]);
+      const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
+      const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
       sum = _mm256_add_epi64(sum, dist1);
       sum = _mm256_add_epi64(sum, dist2);
     }
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c
index df823a0..2c12346 100644
--- a/av1/encoder/x86/av1_k_means_sse2.c
+++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -18,7 +18,7 @@
   const __m128i sum1 = _mm_unpackhi_epi64(a, a);
   const __m128i sum2 = _mm_add_epi64(a, sum1);
   int64_t res;
-  _mm_storel_epi64((__m128i_u *)(&res), sum2);
+  _mm_storel_epi64((__m128i *)&res, sum2);
   return res;
 }
 
@@ -56,8 +56,8 @@
     ind[l] = _mm_packus_epi16(ind[l], v_zero);
     if (total_dist) {
       // Convert to 64 bit and add to sum.
-      const __m128i dist1 = _mm_unpacklo_epi32(v_zero, dist[0]);
-      const __m128i dist2 = _mm_unpackhi_epi32(v_zero, dist[0]);
+      const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
+      const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
       sum = _mm_add_epi64(sum, dist1);
       sum = _mm_add_epi64(sum, dist2);
     }
@@ -73,6 +73,15 @@
   }
 }
 
+static __m128i absolute_diff_epi32(__m128i a, __m128i b) {
+  const __m128i diff1 = _mm_sub_epi32(a, b);
+  const __m128i diff2 = _mm_sub_epi32(b, a);
+  const __m128i cmp = _mm_cmpgt_epi32(diff1, diff2);
+  const __m128i masked1 = _mm_and_si128(cmp, diff1);
+  const __m128i masked2 = _mm_andnot_si128(cmp, diff2);
+  return _mm_or_si128(masked1, masked2);
+}
+
 void av1_calc_indices_dim2_sse2(const int *data, const int *centroids,
                                 uint8_t *indices, int64_t *total_dist, int n,
                                 int k) {
@@ -93,8 +102,8 @@
     for (int j = 0; j < k; j++) {
       __m128i cent0 = _mm_set1_epi32(centroids[2 * j]);
       __m128i cent1 = _mm_set1_epi32(centroids[2 * j + 1]);
-      __m128i d1 = _mm_sub_epi32(ind1, cent0);
-      __m128i d2 = _mm_sub_epi32(ind2, cent1);
+      __m128i d1 = absolute_diff_epi32(ind1, cent0);
+      __m128i d2 = absolute_diff_epi32(ind2, cent1);
       __m128i d3 = _mm_madd_epi16(d1, d1);
       __m128i d4 = _mm_madd_epi16(d2, d2);
       dist[j] = _mm_add_epi32(d3, d4);
@@ -113,8 +122,8 @@
     ind[l] = _mm_packus_epi16(ind[l], v_zero);
     if (total_dist) {
       // Convert to 64 bit and add to sum.
-      const __m128i dist1 = _mm_unpacklo_epi32(v_zero, dist[0]);
-      const __m128i dist2 = _mm_unpackhi_epi32(v_zero, dist[0]);
+      const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
+      const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
       sum = _mm_add_epi64(sum, dist1);
       sum = _mm_add_epi64(sum, dist2);
     }
diff --git a/test/av1_k_means_test.cc b/test/av1_k_means_test.cc
index 2ef837c..5b6c22e 100644
--- a/test/av1_k_means_test.cc
+++ b/test/av1_k_means_test.cc
@@ -94,10 +94,11 @@
   const int w = block_size_wide[bsize];
   const int h = block_size_high[bsize];
   const int n = w * h;
-  av1_calc_indices_dim1_c(data_, centroids_, indices1_, /*total_dist=*/nullptr,
-                          n, k);
-  test_impl(data_, centroids_, indices2_, /*total_dist=*/nullptr, n, k);
+  int64_t total_dist_dim1, total_dist_impl;
+  av1_calc_indices_dim1_c(data_, centroids_, indices1_, &total_dist_dim1, n, k);
+  test_impl(data_, centroids_, indices2_, &total_dist_impl, n, k);
 
+  ASSERT_EQ(total_dist_dim1, total_dist_impl);
   ASSERT_EQ(CheckResult(n), true)
       << " block " << bsize << " index " << n << " Centroids " << k;
 }
@@ -203,10 +204,11 @@
   const int w = block_size_wide[bsize];
   const int h = block_size_high[bsize];
   const int n = w * h;
-  av1_calc_indices_dim2_c(data_, centroids_, indices1_, /*total_dist=*/nullptr,
-                          n, k);
-  test_impl(data_, centroids_, indices2_, /*total_dist=*/nullptr, n, k);
+  int64_t total_dist_dim2, total_dist_impl;
+  av1_calc_indices_dim2_c(data_, centroids_, indices1_, &total_dist_dim2, n, k);
+  test_impl(data_, centroids_, indices2_, &total_dist_impl, n, k);
 
+  ASSERT_EQ(total_dist_dim2, total_dist_impl);
   ASSERT_EQ(CheckResult(n), true)
       << " block " << bsize << " index " << n << " Centroids " << k;
 }
@@ -225,7 +227,7 @@
     aom_usec_timer_start(&timer);
     av1_calc_indices_dim2_func func = funcs[i];
     for (int j = 0; j < num_loops; ++j) {
-      func(data_, centroids_, indices1_, nullptr, n, k);
+      func(data_, centroids_, indices1_, /*total_dist=*/nullptr, n, k);
     }
     aom_usec_timer_mark(&timer);
     double time = static_cast<double>(aom_usec_timer_elapsed(&timer));