K-means: Compute calc_total_dist inside calc_indices.
Bug: b/217282899
On rtc_screen, speed 0, 50 runs, speed-ups are in %:
screen_recording_crd.1920_1080.y4m 0.461
screenshare_buganizer.1900_1306.y4m 0.376
screenshare_colorslides.1820_1320.y4m 0.601
screenshare_slidechanges.1850_1110.y4m 0.410
screenshare_youtube.1680_1178.y4m 0.338
slides_webplot.1920_1080.y4m 0.298
sc_web_browsing720p.y4m 0.901
screen_crd_colwinscroll.1920_1128.y4m 0.082
{OVERALL} 0.433
At speed 10:
screen_recording_crd.1920_1080.y4m 0.334
screenshare_buganizer.1900_1306.y4m 0.774
screenshare_colorslides.1820_1320.y4m 0.678
screenshare_slidechanges.1850_1110.y4m 0.791
screenshare_youtube.1680_1178.y4m 0.170
slides_webplot.1920_1080.y4m 0.640
sc_web_browsing720p.y4m 0.914
screen_crd_colwinscroll.1920_1128.y4m 0.066
{OVERALL} 0.546
Change-Id: Ic2f48e586eae2df96f1e3b0e11b76332f7b76ef9
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 5a05548..ca299aa 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -412,10 +412,10 @@
add_proto qw/void av1_quantize_b/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan, const qm_val_t * qm_ptr, const qm_val_t * iqm_ptr, int log_scale";
- add_proto qw/void av1_calc_indices_dim1/, "const int *data, const int *centroids, uint8_t *indices, int n, int k";
+ add_proto qw/void av1_calc_indices_dim1/, "const int *data, const int *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
specialize qw/av1_calc_indices_dim1 sse2 avx2/;
- add_proto qw/void av1_calc_indices_dim2/, "const int *data, const int *centroids, uint8_t *indices, int n, int k";
+ add_proto qw/void av1_calc_indices_dim2/, "const int *data, const int *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
specialize qw/av1_calc_indices_dim2 sse2 avx2/;
# ENCODEMB INVOKE
diff --git a/av1/encoder/k_means_template.h b/av1/encoder/k_means_template.h
index 87677a1..fe3e9ef 100644
--- a/av1/encoder/k_means_template.h
+++ b/av1/encoder/k_means_template.h
@@ -34,7 +34,10 @@
}
void RENAME(av1_calc_indices)(const int *data, const int *centroids,
- uint8_t *indices, int n, int k) {
+ uint8_t *indices, int64_t *dist, int n, int k) {
+ if (dist) {  // dist is optional: a NULL pointer skips the total.
+ *dist = 0;
+ }
for (int i = 0; i < n; ++i) {
int min_dist = RENAME(calc_dist)(data + i * AV1_K_MEANS_DIM, centroids);
indices[i] = 0;
@@ -46,6 +49,9 @@
indices[i] = j;
}
}
+    if (dist) {
+      *dist += min_dist;  // Add to the pointed-to total, not to the pointer.
+    }
}
}
@@ -80,17 +86,6 @@
}
}
-static int64_t RENAME(calc_total_dist)(const int *data, const int *centroids,
- const uint8_t *indices, int n, int k) {
- int64_t dist = 0;
- (void)k;
- for (int i = 0; i < n; ++i) {
- dist += RENAME(calc_dist)(data + i * AV1_K_MEANS_DIM,
- centroids + indices[i] * AV1_K_MEANS_DIM);
- }
- return dist;
-}
-
void RENAME(av1_k_means)(const int *data, int *centroids, uint8_t *indices,
int n, int k, int max_itr) {
int centroids_tmp[AV1_K_MEANS_DIM * PALETTE_MAX_SIZE];
@@ -103,11 +98,10 @@
assert(n <= MAX_PALETTE_BLOCK_WIDTH * MAX_PALETTE_BLOCK_HEIGHT);
#if AV1_K_MEANS_DIM - 2
- av1_calc_indices_dim1(data, centroids, indices, n, k);
+ av1_calc_indices_dim1(data, centroids, indices, &this_dist, n, k);
#else
- av1_calc_indices_dim2(data, centroids, indices, n, k);
+ av1_calc_indices_dim2(data, centroids, indices, &this_dist, n, k);
#endif
- this_dist = RENAME(calc_total_dist)(data, centroids, indices, n, k);
for (i = 0; i < max_itr; ++i) {
const int64_t prev_dist = this_dist;
@@ -116,12 +110,12 @@
RENAME(calc_centroids)(data, meta_centroids[l], meta_indices[prev_l], n, k);
#if AV1_K_MEANS_DIM - 2
- av1_calc_indices_dim1(data, meta_centroids[l], meta_indices[l], n, k);
+ av1_calc_indices_dim1(data, meta_centroids[l], meta_indices[l], &this_dist,
+ n, k);
#else
- av1_calc_indices_dim2(data, meta_centroids[l], meta_indices[l], n, k);
+ av1_calc_indices_dim2(data, meta_centroids[l], meta_indices[l], &this_dist,
+ n, k);
#endif
- this_dist =
- RENAME(calc_total_dist)(data, meta_centroids[l], meta_indices[l], n, k);
if (this_dist > prev_dist) {
best_l = prev_l;
diff --git a/av1/encoder/palette.h b/av1/encoder/palette.h
index 34d2ddc..6f33f44 100644
--- a/av1/encoder/palette.h
+++ b/av1/encoder/palette.h
@@ -56,9 +56,9 @@
assert(n > 0);
assert(k > 0);
if (dim == 1) {
- av1_calc_indices_dim1(data, centroids, indices, n, k);
+ av1_calc_indices_dim1(data, centroids, indices, /*total_dist=*/NULL, n, k);
} else if (dim == 2) {
- av1_calc_indices_dim2(data, centroids, indices, n, k);
+ av1_calc_indices_dim2(data, centroids, indices, /*total_dist=*/NULL, n, k);
} else {
assert(0 && "Untemplated k means dimension");
}
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
index 759f515..2177358 100644
--- a/av1/encoder/x86/av1_k_means_avx2.c
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -13,10 +13,18 @@
#include "config/aom_dsp_rtcd.h"
#include "aom_dsp/x86/synonyms.h"
+static int64_t k_means_horizontal_sum_avx2(__m256i a) {
+  int64_t dists[4];  // Stack array: only int64_t alignment is guaranteed.
+  _mm256_storeu_si256((__m256i *)dists, a);  // So use the unaligned store.
+  return dists[0] + dists[1] + dists[2] + dists[3];  // Sum of the 4 lanes.
+}
+
void av1_calc_indices_dim1_avx2(const int *data, const int *centroids,
- uint8_t *indices, int n, int k) {
+ uint8_t *indices, int64_t *total_dist, int n,
+ int k) {
__m256i dist[PALETTE_MAX_SIZE];
const __m256i v_zero = _mm256_setzero_si256();
+ __m256i sum = _mm256_setzero_si256();
for (int i = 0; i < n; i += 8) {
__m256i ind = _mm256_loadu_si256((__m256i *)data);
@@ -44,16 +52,29 @@
_mm_storel_epi64((__m128i *)indices, d1);
+    if (total_dist) {
+      // Zero-extend to 64 bit (dist[0] first => low half) and add to sum.
+      const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
+      const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
+      sum = _mm256_add_epi64(sum, dist1);
+      sum = _mm256_add_epi64(sum, dist2);
+    }
+
indices += 8;
data += 8;
}
+ if (total_dist) {  // Optional output: skipped when the pointer is NULL.
+ *total_dist = k_means_horizontal_sum_avx2(sum);
+ }
}
void av1_calc_indices_dim2_avx2(const int *data, const int *centroids,
- uint8_t *indices, int n, int k) {
+ uint8_t *indices, int64_t *total_dist, int n,
+ int k) {
__m256i dist[PALETTE_MAX_SIZE];
const __m256i v_zero = _mm256_setzero_si256();
const __m256i v_permute = _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7);
+ __m256i sum = _mm256_setzero_si256();
for (int i = 0; i < n; i += 8) {
__m256i ind1 = _mm256_loadu_si256((__m256i *)data);
@@ -89,7 +110,18 @@
_mm_storel_epi64((__m128i *)indices, d1);
+    if (total_dist) {
+      // Zero-extend to 64 bit (dist[0] first => low half) and add to sum.
+      const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
+      const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
+      sum = _mm256_add_epi64(sum, dist1);
+      sum = _mm256_add_epi64(sum, dist2);
+    }
+
indices += 8;
data += 16;
}
+ if (total_dist) {  // Optional output: skipped when the pointer is NULL.
+ *total_dist = k_means_horizontal_sum_avx2(sum);
+ }
}
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c
index f03c459..df823a0 100644
--- a/av1/encoder/x86/av1_k_means_sse2.c
+++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -14,12 +14,22 @@
#include "config/aom_dsp_rtcd.h"
#include "aom_dsp/x86/synonyms.h"
+static int64_t k_means_horizontal_sum_sse2(__m128i a) {
+  const __m128i sum1 = _mm_unpackhi_epi64(a, a);  // High lane in both halves.
+  const __m128i sum2 = _mm_add_epi64(a, sum1);    // Low lane = low + high.
+  int64_t res;
+  _mm_storel_epi64((__m128i *)(&res), sum2);  // Portable cast, no __m128i_u.
+  return res;
+}
+
void av1_calc_indices_dim1_sse2(const int *data, const int *centroids,
- uint8_t *indices, int n, int k) {
+ uint8_t *indices, int64_t *total_dist, int n,
+ int k) {
const __m128i v_zero = _mm_setzero_si128();
int l = 1;
__m128i dist[PALETTE_MAX_SIZE];
__m128i ind[2];
+ __m128i sum = _mm_setzero_si128();
for (int i = 0; i < n; i += 4) {
l = (l == 0) ? 1 : 0;
@@ -44,6 +54,13 @@
_mm_or_si128(_mm_andnot_si128(cmp, ind[l]), _mm_and_si128(cmp, ind1));
}
ind[l] = _mm_packus_epi16(ind[l], v_zero);
+    if (total_dist) {
+      // Zero-extend to 64 bit (dist[0] first => low half) and add to sum.
+      const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
+      const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
+      sum = _mm_add_epi64(sum, dist1);
+      sum = _mm_add_epi64(sum, dist2);
+    }
if (l == 1) {
__m128i p2 = _mm_packus_epi16(_mm_unpacklo_epi64(ind[0], ind[1]), v_zero);
_mm_storel_epi64((__m128i *)indices, p2);
@@ -51,14 +68,19 @@
}
data += 4;
}
+ if (total_dist) {  // Optional output: skipped when the pointer is NULL.
+ *total_dist = k_means_horizontal_sum_sse2(sum);
+ }
}
void av1_calc_indices_dim2_sse2(const int *data, const int *centroids,
- uint8_t *indices, int n, int k) {
+ uint8_t *indices, int64_t *total_dist, int n,
+ int k) {
const __m128i v_zero = _mm_setzero_si128();
int l = 1;
__m128i dist[PALETTE_MAX_SIZE];
__m128i ind[2];
+ __m128i sum = _mm_setzero_si128();
for (int i = 0; i < n; i += 4) {
l = (l == 0) ? 1 : 0;
@@ -89,6 +111,13 @@
_mm_or_si128(_mm_andnot_si128(cmp, ind[l]), _mm_and_si128(cmp, ind1));
}
ind[l] = _mm_packus_epi16(ind[l], v_zero);
+    if (total_dist) {
+      // Zero-extend to 64 bit (dist[0] first => low half) and add to sum.
+      const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
+      const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
+      sum = _mm_add_epi64(sum, dist1);
+      sum = _mm_add_epi64(sum, dist2);
+    }
if (l == 1) {
__m128i p2 = _mm_packus_epi16(_mm_unpacklo_epi64(ind[0], ind[1]), v_zero);
_mm_storel_epi64((__m128i *)indices, p2);
@@ -96,4 +125,7 @@
}
data += 8;
}
+ if (total_dist) {  // Optional output: skipped when the pointer is NULL.
+ *total_dist = k_means_horizontal_sum_sse2(sum);
+ }
}
diff --git a/test/av1_k_means_test.cc b/test/av1_k_means_test.cc
index 8a85323..2ef837c 100644
--- a/test/av1_k_means_test.cc
+++ b/test/av1_k_means_test.cc
@@ -30,10 +30,12 @@
namespace AV1Kmeans {
typedef void (*av1_calc_indices_dim1_func)(const int *data,
const int *centroids,
- uint8_t *indices, int n, int k);
+ uint8_t *indices,
+ int64_t *total_dist, int n, int k);
typedef void (*av1_calc_indices_dim2_func)(const int *data,
const int *centroids,
- uint8_t *indices, int n, int k);
+ uint8_t *indices,
+ int64_t *total_dist, int n, int k);
typedef std::tuple<av1_calc_indices_dim1_func, BLOCK_SIZE>
av1_calc_indices_dim1Param;
@@ -92,8 +94,9 @@
const int w = block_size_wide[bsize];
const int h = block_size_high[bsize];
const int n = w * h;
- av1_calc_indices_dim1_c(data_, centroids_, indices1_, n, k);
- test_impl(data_, centroids_, indices2_, n, k);
+ av1_calc_indices_dim1_c(data_, centroids_, indices1_, /*total_dist=*/nullptr,
+ n, k);
+ test_impl(data_, centroids_, indices2_, /*total_dist=*/nullptr, n, k);
ASSERT_EQ(CheckResult(n), true)
<< " block " << bsize << " index " << n << " Centroids " << k;
@@ -113,7 +116,7 @@
aom_usec_timer_start(&timer);
av1_calc_indices_dim1_func func = funcs[i];
for (int j = 0; j < num_loops; ++j) {
- func(data_, centroids_, indices1_, n, k);
+ func(data_, centroids_, indices1_, /*total_dist=*/nullptr, n, k);
}
aom_usec_timer_mark(&timer);
double time = static_cast<double>(aom_usec_timer_elapsed(&timer));
@@ -200,8 +203,9 @@
const int w = block_size_wide[bsize];
const int h = block_size_high[bsize];
const int n = w * h;
- av1_calc_indices_dim2_c(data_, centroids_, indices1_, n, k);
- test_impl(data_, centroids_, indices2_, n, k);
+ av1_calc_indices_dim2_c(data_, centroids_, indices1_, /*total_dist=*/nullptr,
+ n, k);
+ test_impl(data_, centroids_, indices2_, /*total_dist=*/nullptr, n, k);
ASSERT_EQ(CheckResult(n), true)
<< " block " << bsize << " index " << n << " Centroids " << k;
@@ -221,7 +225,7 @@
aom_usec_timer_start(&timer);
av1_calc_indices_dim2_func func = funcs[i];
for (int j = 0; j < num_loops; ++j) {
- func(data_, centroids_, indices1_, n, k);
+ func(data_, centroids_, indices1_, /*total_dist=*/nullptr, n, k);
}
aom_usec_timer_mark(&timer);
double time = static_cast<double>(aom_usec_timer_elapsed(&timer));