K-means: Compute calc_total_dist inside calc_indices.

Bug: b/217282899

On rtc_screen, speed 0, 50 runs, speed-ups are in %:

screen_recording_crd.1920_1080.y4m	0.461
screenshare_buganizer.1900_1306.y4m	0.376
screenshare_colorslides.1820_1320.y4m	0.601
screenshare_slidechanges.1850_1110.y4m	0.410
screenshare_youtube.1680_1178.y4m	0.338
slides_webplot.1920_1080.y4m	0.298
sc_web_browsing720p.y4m	0.901
screen_crd_colwinscroll.1920_1128.y4m	0.082
{OVERALL}	0.433

At speed 10:

screen_recording_crd.1920_1080.y4m	0.334
screenshare_buganizer.1900_1306.y4m	0.774
screenshare_colorslides.1820_1320.y4m	0.678
screenshare_slidechanges.1850_1110.y4m	0.791
screenshare_youtube.1680_1178.y4m	0.170
slides_webplot.1920_1080.y4m	0.640
sc_web_browsing720p.y4m	0.914
screen_crd_colwinscroll.1920_1128.y4m	0.066
{OVERALL}	0.546

Change-Id: Ic2f48e586eae2df96f1e3b0e11b76332f7b76ef9
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 5a05548..ca299aa 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -412,10 +412,10 @@
 
   add_proto qw/void av1_quantize_b/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan, const qm_val_t * qm_ptr, const qm_val_t * iqm_ptr, int log_scale";
 
-  add_proto qw/void av1_calc_indices_dim1/, "const int *data, const int *centroids, uint8_t *indices, int n, int k";
+  add_proto qw/void av1_calc_indices_dim1/, "const int *data, const int *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
   specialize qw/av1_calc_indices_dim1 sse2 avx2/;
 
-  add_proto qw/void av1_calc_indices_dim2/, "const int *data, const int *centroids, uint8_t *indices, int n, int k";
+  add_proto qw/void av1_calc_indices_dim2/, "const int *data, const int *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
   specialize qw/av1_calc_indices_dim2 sse2 avx2/;
 
   # ENCODEMB INVOKE
diff --git a/av1/encoder/k_means_template.h b/av1/encoder/k_means_template.h
index 87677a1..fe3e9ef 100644
--- a/av1/encoder/k_means_template.h
+++ b/av1/encoder/k_means_template.h
@@ -34,7 +34,10 @@
 }
 
 void RENAME(av1_calc_indices)(const int *data, const int *centroids,
-                              uint8_t *indices, int n, int k) {
+                              uint8_t *indices, int64_t *dist, int n, int k) {
+  if (dist) {
+    *dist = 0;
+  }
   for (int i = 0; i < n; ++i) {
     int min_dist = RENAME(calc_dist)(data + i * AV1_K_MEANS_DIM, centroids);
     indices[i] = 0;
@@ -46,6 +49,9 @@
         indices[i] = j;
       }
     }
+    if (dist) {
+      *dist += min_dist;  // Add to the caller's total; do not move the pointer.
+    }
   }
 }
 
@@ -80,17 +86,6 @@
   }
 }
 
-static int64_t RENAME(calc_total_dist)(const int *data, const int *centroids,
-                                       const uint8_t *indices, int n, int k) {
-  int64_t dist = 0;
-  (void)k;
-  for (int i = 0; i < n; ++i) {
-    dist += RENAME(calc_dist)(data + i * AV1_K_MEANS_DIM,
-                              centroids + indices[i] * AV1_K_MEANS_DIM);
-  }
-  return dist;
-}
-
 void RENAME(av1_k_means)(const int *data, int *centroids, uint8_t *indices,
                          int n, int k, int max_itr) {
   int centroids_tmp[AV1_K_MEANS_DIM * PALETTE_MAX_SIZE];
@@ -103,11 +98,10 @@
   assert(n <= MAX_PALETTE_BLOCK_WIDTH * MAX_PALETTE_BLOCK_HEIGHT);
 
 #if AV1_K_MEANS_DIM - 2
-  av1_calc_indices_dim1(data, centroids, indices, n, k);
+  av1_calc_indices_dim1(data, centroids, indices, &this_dist, n, k);
 #else
-  av1_calc_indices_dim2(data, centroids, indices, n, k);
+  av1_calc_indices_dim2(data, centroids, indices, &this_dist, n, k);
 #endif
-  this_dist = RENAME(calc_total_dist)(data, centroids, indices, n, k);
 
   for (i = 0; i < max_itr; ++i) {
     const int64_t prev_dist = this_dist;
@@ -116,12 +110,12 @@
 
     RENAME(calc_centroids)(data, meta_centroids[l], meta_indices[prev_l], n, k);
 #if AV1_K_MEANS_DIM - 2
-    av1_calc_indices_dim1(data, meta_centroids[l], meta_indices[l], n, k);
+    av1_calc_indices_dim1(data, meta_centroids[l], meta_indices[l], &this_dist,
+                          n, k);
 #else
-    av1_calc_indices_dim2(data, meta_centroids[l], meta_indices[l], n, k);
+    av1_calc_indices_dim2(data, meta_centroids[l], meta_indices[l], &this_dist,
+                          n, k);
 #endif
-    this_dist =
-        RENAME(calc_total_dist)(data, meta_centroids[l], meta_indices[l], n, k);
 
     if (this_dist > prev_dist) {
       best_l = prev_l;
diff --git a/av1/encoder/palette.h b/av1/encoder/palette.h
index 34d2ddc..6f33f44 100644
--- a/av1/encoder/palette.h
+++ b/av1/encoder/palette.h
@@ -56,9 +56,9 @@
   assert(n > 0);
   assert(k > 0);
   if (dim == 1) {
-    av1_calc_indices_dim1(data, centroids, indices, n, k);
+    av1_calc_indices_dim1(data, centroids, indices, /*total_dist=*/NULL, n, k);
   } else if (dim == 2) {
-    av1_calc_indices_dim2(data, centroids, indices, n, k);
+    av1_calc_indices_dim2(data, centroids, indices, /*total_dist=*/NULL, n, k);
   } else {
     assert(0 && "Untemplated k means dimension");
   }
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
index 759f515..2177358 100644
--- a/av1/encoder/x86/av1_k_means_avx2.c
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -13,10 +13,18 @@
 #include "config/aom_dsp_rtcd.h"
 #include "aom_dsp/x86/synonyms.h"
 
+static int64_t k_means_horizontal_sum_avx2(__m256i a) {
+  int64_t dists[4];  // Only 8-byte alignment is guaranteed: store unaligned.
+  _mm256_storeu_si256((__m256i *)dists, a);
+  return dists[0] + dists[1] + dists[2] + dists[3];  // Sum the 64-bit lanes.
+}
+
 void av1_calc_indices_dim1_avx2(const int *data, const int *centroids,
-                                uint8_t *indices, int n, int k) {
+                                uint8_t *indices, int64_t *total_dist, int n,
+                                int k) {
   __m256i dist[PALETTE_MAX_SIZE];
   const __m256i v_zero = _mm256_setzero_si256();
+  __m256i sum = _mm256_setzero_si256();
 
   for (int i = 0; i < n; i += 8) {
     __m256i ind = _mm256_loadu_si256((__m256i *)data);
@@ -44,16 +52,29 @@
 
     _mm_storel_epi64((__m128i *)indices, d1);
 
+    if (total_dist) {
+      // Zero-extend dist[0] to 64 bits (value low, zero high) and add to sum.
+      const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
+      const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
+      sum = _mm256_add_epi64(sum, dist1);
+      sum = _mm256_add_epi64(sum, dist2);
+    }
+
     indices += 8;
     data += 8;
   }
+  if (total_dist) {
+    *total_dist = k_means_horizontal_sum_avx2(sum);
+  }
 }
 
 void av1_calc_indices_dim2_avx2(const int *data, const int *centroids,
-                                uint8_t *indices, int n, int k) {
+                                uint8_t *indices, int64_t *total_dist, int n,
+                                int k) {
   __m256i dist[PALETTE_MAX_SIZE];
   const __m256i v_zero = _mm256_setzero_si256();
   const __m256i v_permute = _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7);
+  __m256i sum = _mm256_setzero_si256();
 
   for (int i = 0; i < n; i += 8) {
     __m256i ind1 = _mm256_loadu_si256((__m256i *)data);
@@ -89,7 +110,18 @@
 
     _mm_storel_epi64((__m128i *)indices, d1);
 
+    if (total_dist) {
+      // Zero-extend dist[0] to 64 bits (value low, zero high) and add to sum.
+      const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
+      const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
+      sum = _mm256_add_epi64(sum, dist1);
+      sum = _mm256_add_epi64(sum, dist2);
+    }
+
     indices += 8;
     data += 16;
   }
+  if (total_dist) {
+    *total_dist = k_means_horizontal_sum_avx2(sum);
+  }
 }
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c
index f03c459..df823a0 100644
--- a/av1/encoder/x86/av1_k_means_sse2.c
+++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -14,12 +14,22 @@
 #include "config/aom_dsp_rtcd.h"
 #include "aom_dsp/x86/synonyms.h"
 
+static int64_t k_means_horizontal_sum_sse2(__m128i a) {
+  const __m128i sum1 = _mm_unpackhi_epi64(a, a);  // High lane in both halves.
+  const __m128i sum2 = _mm_add_epi64(a, sum1);    // Low lane = low + high.
+  int64_t res;
+  _mm_storel_epi64((__m128i *)(&res), sum2);  // Writes only the low 64 bits.
+  return res;
+}
+
 void av1_calc_indices_dim1_sse2(const int *data, const int *centroids,
-                                uint8_t *indices, int n, int k) {
+                                uint8_t *indices, int64_t *total_dist, int n,
+                                int k) {
   const __m128i v_zero = _mm_setzero_si128();
   int l = 1;
   __m128i dist[PALETTE_MAX_SIZE];
   __m128i ind[2];
+  __m128i sum = _mm_setzero_si128();
 
   for (int i = 0; i < n; i += 4) {
     l = (l == 0) ? 1 : 0;
@@ -44,6 +54,13 @@
           _mm_or_si128(_mm_andnot_si128(cmp, ind[l]), _mm_and_si128(cmp, ind1));
     }
     ind[l] = _mm_packus_epi16(ind[l], v_zero);
+    if (total_dist) {
+      // Zero-extend dist[0] to 64 bits (value low, zero high) and add to sum.
+      const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
+      const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
+      sum = _mm_add_epi64(sum, dist1);
+      sum = _mm_add_epi64(sum, dist2);
+    }
     if (l == 1) {
       __m128i p2 = _mm_packus_epi16(_mm_unpacklo_epi64(ind[0], ind[1]), v_zero);
       _mm_storel_epi64((__m128i *)indices, p2);
@@ -51,14 +68,19 @@
     }
     data += 4;
   }
+  if (total_dist) {
+    *total_dist = k_means_horizontal_sum_sse2(sum);
+  }
 }
 
 void av1_calc_indices_dim2_sse2(const int *data, const int *centroids,
-                                uint8_t *indices, int n, int k) {
+                                uint8_t *indices, int64_t *total_dist, int n,
+                                int k) {
   const __m128i v_zero = _mm_setzero_si128();
   int l = 1;
   __m128i dist[PALETTE_MAX_SIZE];
   __m128i ind[2];
+  __m128i sum = _mm_setzero_si128();
 
   for (int i = 0; i < n; i += 4) {
     l = (l == 0) ? 1 : 0;
@@ -89,6 +111,13 @@
           _mm_or_si128(_mm_andnot_si128(cmp, ind[l]), _mm_and_si128(cmp, ind1));
     }
     ind[l] = _mm_packus_epi16(ind[l], v_zero);
+    if (total_dist) {
+      // Zero-extend dist[0] to 64 bits (value low, zero high) and add to sum.
+      const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
+      const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
+      sum = _mm_add_epi64(sum, dist1);
+      sum = _mm_add_epi64(sum, dist2);
+    }
     if (l == 1) {
       __m128i p2 = _mm_packus_epi16(_mm_unpacklo_epi64(ind[0], ind[1]), v_zero);
       _mm_storel_epi64((__m128i *)indices, p2);
@@ -96,4 +125,7 @@
     }
     data += 8;
   }
+  if (total_dist) {
+    *total_dist = k_means_horizontal_sum_sse2(sum);
+  }
 }
diff --git a/test/av1_k_means_test.cc b/test/av1_k_means_test.cc
index 8a85323..2ef837c 100644
--- a/test/av1_k_means_test.cc
+++ b/test/av1_k_means_test.cc
@@ -30,10 +30,12 @@
 namespace AV1Kmeans {
 typedef void (*av1_calc_indices_dim1_func)(const int *data,
                                            const int *centroids,
-                                           uint8_t *indices, int n, int k);
+                                           uint8_t *indices,
+                                           int64_t *total_dist, int n, int k);
 typedef void (*av1_calc_indices_dim2_func)(const int *data,
                                            const int *centroids,
-                                           uint8_t *indices, int n, int k);
+                                           uint8_t *indices,
+                                           int64_t *total_dist, int n, int k);
 
 typedef std::tuple<av1_calc_indices_dim1_func, BLOCK_SIZE>
     av1_calc_indices_dim1Param;
@@ -92,8 +94,9 @@
   const int w = block_size_wide[bsize];
   const int h = block_size_high[bsize];
   const int n = w * h;
-  av1_calc_indices_dim1_c(data_, centroids_, indices1_, n, k);
-  test_impl(data_, centroids_, indices2_, n, k);
+  av1_calc_indices_dim1_c(data_, centroids_, indices1_, /*total_dist=*/nullptr,
+                          n, k);
+  test_impl(data_, centroids_, indices2_, /*total_dist=*/nullptr, n, k);
 
   ASSERT_EQ(CheckResult(n), true)
       << " block " << bsize << " index " << n << " Centroids " << k;
@@ -113,7 +116,7 @@
     aom_usec_timer_start(&timer);
     av1_calc_indices_dim1_func func = funcs[i];
     for (int j = 0; j < num_loops; ++j) {
-      func(data_, centroids_, indices1_, n, k);
+      func(data_, centroids_, indices1_, /*total_dist=*/nullptr, n, k);
     }
     aom_usec_timer_mark(&timer);
     double time = static_cast<double>(aom_usec_timer_elapsed(&timer));
@@ -200,8 +203,9 @@
   const int w = block_size_wide[bsize];
   const int h = block_size_high[bsize];
   const int n = w * h;
-  av1_calc_indices_dim2_c(data_, centroids_, indices1_, n, k);
-  test_impl(data_, centroids_, indices2_, n, k);
+  av1_calc_indices_dim2_c(data_, centroids_, indices1_, /*total_dist=*/nullptr,
+                          n, k);
+  test_impl(data_, centroids_, indices2_, /*total_dist=*/nullptr, n, k);
 
   ASSERT_EQ(CheckResult(n), true)
       << " block " << bsize << " index " << n << " Centroids " << k;
@@ -221,7 +225,7 @@
     aom_usec_timer_start(&timer);
     av1_calc_indices_dim2_func func = funcs[i];
     for (int j = 0; j < num_loops; ++j) {
-      func(data_, centroids_, indices1_, n, k);
+      func(data_, centroids_, indices1_, /*total_dist=*/nullptr, n, k);
     }
     aom_usec_timer_mark(&timer);
     double time = static_cast<double>(aom_usec_timer_elapsed(&timer));