K-means: misc refactoring and speed-ups.
Previous suggestions are implemented:
- store centroids once and for all
- do not store all distances: keep only the current one and the minimum one
- store twice as many indices per iteration in the 2D AVX2 path
On rtc_screen, speed 0, 50 runs, speed-ups are in %:
File encoding_spdup:
screen_recording_crd.1920_1080.y4m 0.289
screenshare_buganizer.1900_1306.y4m 0.136
screenshare_colorslides.1820_1320.y4m 0.232
screenshare_slidechanges.1850_1110.y4m 0.110
screenshare_youtube.1680_1178.y4m 0.534
slides_webplot.1920_1080.y4m 0.112
sc_web_browsing720p.y4m 0.548
screen_crd_colwinscroll.1920_1128.y4m 0.225
{OVERALL} 0.273
Speed 10:
File encoding_spdup:
screen_recording_crd.1920_1080.y4m 0.277
screenshare_buganizer.1900_1306.y4m 0.382
screenshare_colorslides.1820_1320.y4m 0.391
screenshare_slidechanges.1850_1110.y4m 0.341
screenshare_youtube.1680_1178.y4m 0.813
slides_webplot.1920_1080.y4m 0.287
sc_web_browsing720p.y4m 0.611
screen_crd_colwinscroll.1920_1128.y4m 0.234
{OVERALL} 0.417
Change-Id: I7178ec5c89c9667aac1818c9eb1eda33ef8f81a9
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
index a2db222..ad0b374 100644
--- a/av1/encoder/x86/av1_k_means_avx2.c
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -26,39 +26,44 @@
void av1_calc_indices_dim1_avx2(const int16_t *data, const int16_t *centroids,
uint8_t *indices, int64_t *total_dist, int n,
int k) {
- __m256i dist[PALETTE_MAX_SIZE];
const __m256i v_zero = _mm256_setzero_si256();
__m256i sum = _mm256_setzero_si256();
+ __m256i cents[PALETTE_MAX_SIZE];
+ for (int j = 0; j < k; ++j) {
+ cents[j] = _mm256_set1_epi16(centroids[j]);
+ }
for (int i = 0; i < n; i += 16) {
- __m256i ind = _mm256_loadu_si256((__m256i *)data);
- for (int j = 0; j < k; j++) {
- __m256i cent = _mm256_set1_epi16(centroids[j]);
- __m256i d1 = _mm256_sub_epi16(ind, cent);
- dist[j] = _mm256_abs_epi16(d1);
- }
+ const __m256i in = _mm256_loadu_si256((__m256i *)data);
+ __m256i ind = _mm256_setzero_si256();
+ // Compute the distance to the first centroid.
+ __m256i d1 = _mm256_sub_epi16(in, cents[0]);
+ __m256i dist_min = _mm256_abs_epi16(d1);
- ind = _mm256_setzero_si256();
- for (int j = 1; j < k; j++) {
- __m256i cmp = _mm256_cmpgt_epi16(dist[0], dist[j]);
- dist[0] = _mm256_min_epi16(dist[0], dist[j]);
- __m256i ind1 = _mm256_set1_epi16(j);
+ for (int j = 1; j < k; ++j) {
+ // Compute the distance to the centroid.
+ d1 = _mm256_sub_epi16(in, cents[j]);
+ const __m256i dist = _mm256_abs_epi16(d1);
+ // Compare to the minimal one.
+ const __m256i cmp = _mm256_cmpgt_epi16(dist_min, dist);
+ dist_min = _mm256_min_epi16(dist_min, dist);
+ const __m256i ind1 = _mm256_set1_epi16(j);
ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind),
_mm256_and_si256(cmp, ind1));
}
- __m256i p1 = _mm256_packus_epi16(ind, v_zero);
- __m256i px = _mm256_permute4x64_epi64(p1, 0x58);
- __m128i d1 = _mm256_extracti128_si256(px, 0);
+ const __m256i p1 = _mm256_packus_epi16(ind, v_zero);
+ const __m256i px = _mm256_permute4x64_epi64(p1, 0x58);
+ const __m128i d2 = _mm256_extracti128_si256(px, 0);
- _mm_storeu_si128((__m128i *)indices, d1);
+ _mm_storeu_si128((__m128i *)indices, d2);
if (total_dist) {
// Square, convert to 32 bit and add together.
- dist[0] = _mm256_madd_epi16(dist[0], dist[0]);
+ dist_min = _mm256_madd_epi16(dist_min, dist_min);
// Convert to 64 bit and add to sum.
- const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
- const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
+ const __m256i dist1 = _mm256_unpacklo_epi32(dist_min, v_zero);
+ const __m256i dist2 = _mm256_unpackhi_epi32(dist_min, v_zero);
sum = _mm256_add_epi64(sum, dist1);
sum = _mm256_add_epi64(sum, dist2);
}
@@ -74,46 +79,52 @@
void av1_calc_indices_dim2_avx2(const int16_t *data, const int16_t *centroids,
uint8_t *indices, int64_t *total_dist, int n,
int k) {
- __m256i dist[PALETTE_MAX_SIZE];
const __m256i v_zero = _mm256_setzero_si256();
+ const __m256i permute = _mm256_set_epi32(0, 0, 0, 0, 5, 1, 4, 0);
__m256i sum = _mm256_setzero_si256();
+ __m256i ind[2];
+ __m256i cents[PALETTE_MAX_SIZE];
+ for (int j = 0; j < k; ++j) {
+ const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
+ cents[j] = _mm256_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx, cy, cx, cy, cx,
+ cy, cx, cy, cx);
+ }
- for (int i = 0; i < n; i += 8) {
- __m256i ind = _mm256_loadu_si256((__m256i *)data);
- for (int j = 0; j < k; j++) {
- const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
- const __m256i cent = _mm256_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx, cy,
- cx, cy, cx, cy, cx, cy, cx);
- const __m256i d1 = _mm256_sub_epi16(ind, cent);
- dist[j] = _mm256_madd_epi16(d1, d1);
+ for (int i = 0; i < n; i += 16) {
+ for (int l = 0; l < 2; ++l) {
+ const __m256i in = _mm256_loadu_si256((__m256i *)data);
+ ind[l] = _mm256_setzero_si256();
+ // Compute the distance to the first centroid.
+ __m256i d1 = _mm256_sub_epi16(in, cents[0]);
+ __m256i dist_min = _mm256_madd_epi16(d1, d1);
+
+ for (int j = 1; j < k; ++j) {
+ // Compute the distance to the centroid.
+ d1 = _mm256_sub_epi16(in, cents[j]);
+ const __m256i dist = _mm256_madd_epi16(d1, d1);
+ // Compare to the minimal one.
+ const __m256i cmp = _mm256_cmpgt_epi32(dist_min, dist);
+ dist_min = _mm256_min_epi32(dist_min, dist);
+ const __m256i ind1 = _mm256_set1_epi32(j);
+ ind[l] = _mm256_or_si256(_mm256_andnot_si256(cmp, ind[l]),
+ _mm256_and_si256(cmp, ind1));
+ }
+ if (total_dist) {
+ // Convert to 64 bit and add to sum.
+ const __m256i dist1 = _mm256_unpacklo_epi32(dist_min, v_zero);
+ const __m256i dist2 = _mm256_unpackhi_epi32(dist_min, v_zero);
+ sum = _mm256_add_epi64(sum, dist1);
+ sum = _mm256_add_epi64(sum, dist2);
+ }
+ data += 16;
}
-
- ind = _mm256_setzero_si256();
- for (int j = 1; j < k; j++) {
- __m256i cmp = _mm256_cmpgt_epi32(dist[0], dist[j]);
- dist[0] = _mm256_min_epi32(dist[0], dist[j]);
- const __m256i ind1 = _mm256_set1_epi32(j);
- ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind),
- _mm256_and_si256(cmp, ind1));
- }
-
- __m256i p1 = _mm256_packus_epi32(ind, v_zero);
- __m256i px = _mm256_permute4x64_epi64(p1, 0x58);
- __m256i p2 = _mm256_packus_epi16(px, v_zero);
- __m128i d1 = _mm256_extracti128_si256(p2, 0);
-
- _mm_storel_epi64((__m128i *)indices, d1);
-
- if (total_dist) {
- // Convert to 64 bit and add to sum.
- const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
- const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
- sum = _mm256_add_epi64(sum, dist1);
- sum = _mm256_add_epi64(sum, dist2);
- }
-
- indices += 8;
- data += 16;
+ // Cast to 8 bit and store.
+ const __m256i d2 = _mm256_packus_epi32(ind[0], ind[1]);
+ const __m256i d3 = _mm256_packus_epi16(d2, v_zero);
+ const __m256i d4 = _mm256_permutevar8x32_epi32(d3, permute);
+ const __m128i d5 = _mm256_extracti128_si256(d4, 0);
+ _mm_storeu_si128((__m128i *)indices, d5);
+ indices += 16;
}
if (total_dist) {
*total_dist = k_means_horizontal_sum_avx2(sum);
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c
index a284fa9..4338bf7 100644
--- a/av1/encoder/x86/av1_k_means_sse2.c
+++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -26,31 +26,37 @@
uint8_t *indices, int64_t *total_dist, int n,
int k) {
const __m128i v_zero = _mm_setzero_si128();
- __m128i dist[PALETTE_MAX_SIZE];
__m128i sum = _mm_setzero_si128();
+ __m128i cents[PALETTE_MAX_SIZE];
+ for (int j = 0; j < k; ++j) {
+ cents[j] = _mm_set1_epi16(centroids[j]);
+ }
for (int i = 0; i < n; i += 8) {
- __m128i in = _mm_loadu_si128((__m128i *)data);
- for (int j = 0; j < k; j++) {
- __m128i cent = _mm_set1_epi16(centroids[j]);
- __m128i d1 = _mm_sub_epi16(in, cent);
- __m128i d2 = _mm_sub_epi16(cent, in);
- dist[j] = _mm_max_epi16(d1, d2);
- }
-
+ const __m128i in = _mm_loadu_si128((__m128i *)data);
__m128i ind = _mm_setzero_si128();
- for (int j = 1; j < k; j++) {
- __m128i cmp = _mm_cmpgt_epi16(dist[0], dist[j]);
- dist[0] = _mm_min_epi16(dist[0], dist[j]);
- __m128i ind1 = _mm_set1_epi16(j);
+ // Compute the distance to the first centroid.
+ __m128i d1 = _mm_sub_epi16(in, cents[0]);
+ __m128i d2 = _mm_sub_epi16(cents[0], in);
+ __m128i dist_min = _mm_max_epi16(d1, d2);
+
+ for (int j = 1; j < k; ++j) {
+ // Compute the distance to the centroid.
+ d1 = _mm_sub_epi16(in, cents[j]);
+ d2 = _mm_sub_epi16(cents[j], in);
+ const __m128i dist = _mm_max_epi16(d1, d2);
+ // Compare to the minimal one.
+ const __m128i cmp = _mm_cmpgt_epi16(dist_min, dist);
+ dist_min = _mm_min_epi16(dist_min, dist);
+ const __m128i ind1 = _mm_set1_epi16(j);
ind = _mm_or_si128(_mm_andnot_si128(cmp, ind), _mm_and_si128(cmp, ind1));
}
if (total_dist) {
// Square, convert to 32 bit and add together.
- dist[0] = _mm_madd_epi16(dist[0], dist[0]);
+ dist_min = _mm_madd_epi16(dist_min, dist_min);
// Convert to 64 bit and add to sum.
- const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
- const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
+ const __m128i dist1 = _mm_unpacklo_epi32(dist_min, v_zero);
+ const __m128i dist2 = _mm_unpackhi_epi32(dist_min, v_zero);
sum = _mm_add_epi64(sum, dist1);
sum = _mm_add_epi64(sum, dist2);
}
@@ -68,45 +74,49 @@
uint8_t *indices, int64_t *total_dist, int n,
int k) {
const __m128i v_zero = _mm_setzero_si128();
- int l = 1;
- __m128i dist[PALETTE_MAX_SIZE];
- __m128i ind[2];
__m128i sum = _mm_setzero_si128();
+ __m128i ind[2];
+ __m128i cents[PALETTE_MAX_SIZE];
+ for (int j = 0; j < k; ++j) {
+ const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
+ cents[j] = _mm_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx);
+ }
- for (int i = 0; i < n; i += 4) {
- l = (l == 0) ? 1 : 0;
- __m128i ind1 = _mm_loadu_si128((__m128i *)data);
- for (int j = 0; j < k; j++) {
- const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
- const __m128i cent = _mm_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx);
- const __m128i d1 = _mm_sub_epi16(ind1, cent);
- dist[j] = _mm_madd_epi16(d1, d1);
- }
+ for (int i = 0; i < n; i += 8) {
+ for (int l = 0; l < 2; ++l) {
+ const __m128i in = _mm_loadu_si128((__m128i *)data);
+ ind[l] = _mm_setzero_si128();
+ // Compute the distance to the first centroid.
+ __m128i d1 = _mm_sub_epi16(in, cents[0]);
+ __m128i dist_min = _mm_madd_epi16(d1, d1);
- ind[l] = _mm_setzero_si128();
- for (int j = 1; j < k; j++) {
- __m128i cmp = _mm_cmpgt_epi32(dist[0], dist[j]);
- __m128i dist1 = _mm_andnot_si128(cmp, dist[0]);
- __m128i dist2 = _mm_and_si128(cmp, dist[j]);
- dist[0] = _mm_or_si128(dist1, dist2);
- ind1 = _mm_set1_epi32(j);
- ind[l] =
- _mm_or_si128(_mm_andnot_si128(cmp, ind[l]), _mm_and_si128(cmp, ind1));
+ for (int j = 1; j < k; ++j) {
+ // Compute the distance to the centroid.
+ d1 = _mm_sub_epi16(in, cents[j]);
+ const __m128i dist = _mm_madd_epi16(d1, d1);
+ // Compare to the minimal one.
+ const __m128i cmp = _mm_cmpgt_epi32(dist_min, dist);
+ const __m128i dist1 = _mm_andnot_si128(cmp, dist_min);
+ const __m128i dist2 = _mm_and_si128(cmp, dist);
+ dist_min = _mm_or_si128(dist1, dist2);
+ const __m128i ind1 = _mm_set1_epi32(j);
+ ind[l] = _mm_or_si128(_mm_andnot_si128(cmp, ind[l]),
+ _mm_and_si128(cmp, ind1));
+ }
+ if (total_dist) {
+ // Convert to 64 bit and add to sum.
+ const __m128i dist1 = _mm_unpacklo_epi32(dist_min, v_zero);
+ const __m128i dist2 = _mm_unpackhi_epi32(dist_min, v_zero);
+ sum = _mm_add_epi64(sum, dist1);
+ sum = _mm_add_epi64(sum, dist2);
+ }
+ data += 8;
}
- ind[l] = _mm_packus_epi16(ind[l], v_zero);
- if (total_dist) {
- // Convert to 64 bit and add to sum.
- const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
- const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
- sum = _mm_add_epi64(sum, dist1);
- sum = _mm_add_epi64(sum, dist2);
- }
- if (l == 1) {
- __m128i p2 = _mm_packus_epi16(_mm_unpacklo_epi64(ind[0], ind[1]), v_zero);
- _mm_storel_epi64((__m128i *)indices, p2);
- indices += 8;
- }
- data += 8;
+ // Cast to 8 bit and store.
+ const __m128i d2 = _mm_packus_epi16(ind[0], ind[1]);
+ const __m128i d3 = _mm_packus_epi16(d2, v_zero);
+ _mm_storel_epi64((__m128i *)indices, d3);
+ indices += 8;
}
if (total_dist) {
*total_dist = k_means_horizontal_sum_sse2(sum);