K-means: use L1 norm when dim == 1.
In 1 dimension, finding the nearest neighbor using L2 or L1 is
equivalent. Only the best distance needs to be squared.
The speed-up is negligible: the goal is to make the code easier
to transition to 16-bit inputs.
Change-Id: Ibbff387aec1055655b4d96f3b28d0dcfddd80e6e
diff --git a/av1/encoder/k_means_template.h b/av1/encoder/k_means_template.h
index 9560da5..0cf72b7 100644
--- a/av1/encoder/k_means_template.h
+++ b/av1/encoder/k_means_template.h
@@ -11,6 +11,7 @@
#include <assert.h>
#include <stdint.h>
+#include <stdlib.h>
#include <string.h>
#include "av1/common/blockd.h"
@@ -24,13 +25,20 @@
#define RENAME_(x, y) AV1_K_MEANS_RENAME(x, y)
#define RENAME(x) RENAME_(x, AV1_K_MEANS_DIM)
+// Though we want to compute the smallest L2 norm, in 1 dimension,
+// it is equivalent to finding the smallest L1 norm and then squaring it.
+// This is preferable for speed, especially on the SIMD side.
static int RENAME(calc_dist)(const int *p1, const int *p2) {
+#if AV1_K_MEANS_DIM == 1
+ return abs(p1[0] - p2[0]);
+#else
int dist = 0;
for (int i = 0; i < AV1_K_MEANS_DIM; ++i) {
const int diff = p1[i] - p2[i];
dist += diff * diff;
}
return dist;
+#endif
}
void RENAME(av1_calc_indices)(const int *data, const int *centroids,
@@ -50,7 +58,11 @@
}
}
if (dist) {
+#if AV1_K_MEANS_DIM == 1
+ *dist += min_dist * min_dist;
+#else
*dist += min_dist;
+#endif
}
}
}
@@ -97,7 +109,7 @@
assert(n <= MAX_PALETTE_BLOCK_WIDTH * MAX_PALETTE_BLOCK_HEIGHT);
-#if AV1_K_MEANS_DIM - 2
+#if AV1_K_MEANS_DIM == 1
av1_calc_indices_dim1(data, centroids, indices, &this_dist, n, k);
#else
av1_calc_indices_dim2(data, centroids, indices, &this_dist, n, k);
@@ -109,7 +121,7 @@
l = (l == 1) ? 0 : 1;
RENAME(calc_centroids)(data, meta_centroids[l], meta_indices[prev_l], n, k);
-#if AV1_K_MEANS_DIM - 2
+#if AV1_K_MEANS_DIM == 1
av1_calc_indices_dim1(data, meta_centroids[l], meta_indices[l], &this_dist,
n, k);
#else
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
index 2745ac1..854b90c 100644
--- a/av1/encoder/x86/av1_k_means_avx2.c
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -35,7 +35,7 @@
for (int j = 0; j < k; j++) {
__m256i cent = _mm256_set1_epi32(centroids[j]);
__m256i d1 = _mm256_sub_epi32(ind, cent);
- dist[j] = _mm256_mullo_epi32(d1, d1);
+ dist[j] = _mm256_abs_epi32(d1);
}
ind = _mm256_setzero_si256();
@@ -57,7 +57,8 @@
_mm_storel_epi64((__m128i *)indices, d1);
if (total_dist) {
- // Convert to 64 bit and add to sum.
+ // Square, convert to 64 bit and add to sum.
+ dist[0] = _mm256_mullo_epi32(dist[0], dist[0]);
const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
sum = _mm256_add_epi64(sum, dist1);
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c
index 2c12346..d2c7796 100644
--- a/av1/encoder/x86/av1_k_means_sse2.c
+++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -22,6 +22,15 @@
return res;
}
+static __m128i absolute_diff_epi32(__m128i a, __m128i b) {
+ const __m128i diff1 = _mm_sub_epi32(a, b);
+ const __m128i diff2 = _mm_sub_epi32(b, a);
+ const __m128i cmp = _mm_cmpgt_epi32(diff1, diff2);
+ const __m128i masked1 = _mm_and_si128(cmp, diff1);
+ const __m128i masked2 = _mm_andnot_si128(cmp, diff2);
+ return _mm_or_si128(masked1, masked2);
+}
+
void av1_calc_indices_dim1_sse2(const int *data, const int *centroids,
uint8_t *indices, int64_t *total_dist, int n,
int k) {
@@ -36,11 +45,7 @@
ind[l] = _mm_loadu_si128((__m128i *)data);
for (int j = 0; j < k; j++) {
__m128i cent = _mm_set1_epi32(centroids[j]);
- __m128i d1 = _mm_sub_epi32(ind[l], cent);
- __m128i d2 = _mm_packs_epi32(d1, d1);
- __m128i d3 = _mm_mullo_epi16(d2, d2);
- __m128i d4 = _mm_mulhi_epi16(d2, d2);
- dist[j] = _mm_unpacklo_epi16(d3, d4);
+ dist[j] = absolute_diff_epi32(ind[l], cent);
}
ind[l] = _mm_setzero_si128();
@@ -55,6 +60,11 @@
}
ind[l] = _mm_packus_epi16(ind[l], v_zero);
if (total_dist) {
+ // Square and convert to 32 bit.
+ const __m128i d1 = _mm_packs_epi32(dist[0], v_zero);
+ const __m128i d2 = _mm_mullo_epi16(d1, d1);
+ const __m128i d3 = _mm_mulhi_epi16(d1, d1);
+ dist[0] = _mm_unpacklo_epi16(d2, d3);
// Convert to 64 bit and add to sum.
const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
@@ -73,15 +83,6 @@
}
}
-static __m128i absolute_diff_epi32(__m128i a, __m128i b) {
- const __m128i diff1 = _mm_sub_epi32(a, b);
- const __m128i diff2 = _mm_sub_epi32(b, a);
- const __m128i cmp = _mm_cmpgt_epi32(diff1, diff2);
- const __m128i masked1 = _mm_and_si128(cmp, diff1);
- const __m128i masked2 = _mm_andnot_si128(cmp, diff2);
- return _mm_or_si128(masked1, masked2);
-}
-
void av1_calc_indices_dim2_sse2(const int *data, const int *centroids,
uint8_t *indices, int64_t *total_dist, int n,
int k) {