Fix several bugs in K-means.
Mostly, a test was added to make sure total_dist is the same on all
platforms
Some introduced by the recent 8cbaea62af09e92a81853e1787c9912b3ac6999c:
- dist instead of *dist: the C version was not part of my SSIM check
- the horizontal_sum_avx2 function has been re-written to be pure SSE2
(to avoid non-aligned memory problem)
- int64_t conversions were wrong (though the results right because
only a comparison is needed and the results were shifted by 32 bits)
Some much older:
- av1_calc_indices_dim2_sse2 was actually wrong as it was doing
_mm_madd_epi16 on 32 bit which only works for positive numbers, and
is otherwise creating a difference of 1 (-1 * -1).
The solution is to take the absolute value for now. The proper
solution is to switch the pipeline to 16 bit as this is what it
assumes.
Change-Id: Iaeeefa9f41d5262a622f6b64b9221696a2335c8d
diff --git a/av1/encoder/k_means_template.h b/av1/encoder/k_means_template.h
index fe3e9ef..9560da5 100644
--- a/av1/encoder/k_means_template.h
+++ b/av1/encoder/k_means_template.h
@@ -50,7 +50,7 @@
}
}
if (dist) {
- dist += min_dist;
+ *dist += min_dist;
}
}
}
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
index 2177358..2745ac1 100644
--- a/av1/encoder/x86/av1_k_means_avx2.c
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -14,9 +14,13 @@
#include "aom_dsp/x86/synonyms.h"
static int64_t k_means_horizontal_sum_avx2(__m256i a) {
- int64_t dists[4];
- _mm256_store_si256((__m256i *)dists, a);
- return a[0] + a[1] + a[2] + a[3];
+ const __m128i low = _mm256_castsi256_si128(a);
+ const __m128i high = _mm256_extracti128_si256(a, 1);
+ const __m128i sum = _mm_add_epi64(low, high);
+ const __m128i sum_high = _mm_unpackhi_epi64(sum, sum);
+ int64_t res;
+ _mm_storel_epi64((__m128i *)&res, _mm_add_epi64(sum, sum_high));
+ return res;
}
void av1_calc_indices_dim1_avx2(const int *data, const int *centroids,
@@ -54,8 +58,8 @@
if (total_dist) {
// Convert to 64 bit and add to sum.
- const __m256i dist1 = _mm256_unpacklo_epi32(v_zero, dist[0]);
- const __m256i dist2 = _mm256_unpackhi_epi32(v_zero, dist[0]);
+ const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
+ const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
sum = _mm256_add_epi64(sum, dist1);
sum = _mm256_add_epi64(sum, dist2);
}
@@ -112,8 +116,8 @@
if (total_dist) {
// Convert to 64 bit and add to sum.
- const __m256i dist1 = _mm256_unpacklo_epi32(v_zero, dist[0]);
- const __m256i dist2 = _mm256_unpackhi_epi32(v_zero, dist[0]);
+ const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
+ const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
sum = _mm256_add_epi64(sum, dist1);
sum = _mm256_add_epi64(sum, dist2);
}
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c
index df823a0..2c12346 100644
--- a/av1/encoder/x86/av1_k_means_sse2.c
+++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -18,7 +18,7 @@
const __m128i sum1 = _mm_unpackhi_epi64(a, a);
const __m128i sum2 = _mm_add_epi64(a, sum1);
int64_t res;
- _mm_storel_epi64((__m128i_u *)(&res), sum2);
+ _mm_storel_epi64((__m128i *)&res, sum2);
return res;
}
@@ -56,8 +56,8 @@
ind[l] = _mm_packus_epi16(ind[l], v_zero);
if (total_dist) {
// Convert to 64 bit and add to sum.
- const __m128i dist1 = _mm_unpacklo_epi32(v_zero, dist[0]);
- const __m128i dist2 = _mm_unpackhi_epi32(v_zero, dist[0]);
+ const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
+ const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
sum = _mm_add_epi64(sum, dist1);
sum = _mm_add_epi64(sum, dist2);
}
@@ -73,6 +73,15 @@
}
}
+static __m128i absolute_diff_epi32(__m128i a, __m128i b) {
+ const __m128i diff1 = _mm_sub_epi32(a, b);
+ const __m128i diff2 = _mm_sub_epi32(b, a);
+ const __m128i cmp = _mm_cmpgt_epi32(diff1, diff2);
+ const __m128i masked1 = _mm_and_si128(cmp, diff1);
+ const __m128i masked2 = _mm_andnot_si128(cmp, diff2);
+ return _mm_or_si128(masked1, masked2);
+}
+
void av1_calc_indices_dim2_sse2(const int *data, const int *centroids,
uint8_t *indices, int64_t *total_dist, int n,
int k) {
@@ -93,8 +102,8 @@
for (int j = 0; j < k; j++) {
__m128i cent0 = _mm_set1_epi32(centroids[2 * j]);
__m128i cent1 = _mm_set1_epi32(centroids[2 * j + 1]);
- __m128i d1 = _mm_sub_epi32(ind1, cent0);
- __m128i d2 = _mm_sub_epi32(ind2, cent1);
+ __m128i d1 = absolute_diff_epi32(ind1, cent0);
+ __m128i d2 = absolute_diff_epi32(ind2, cent1);
__m128i d3 = _mm_madd_epi16(d1, d1);
__m128i d4 = _mm_madd_epi16(d2, d2);
dist[j] = _mm_add_epi32(d3, d4);
@@ -113,8 +122,8 @@
ind[l] = _mm_packus_epi16(ind[l], v_zero);
if (total_dist) {
// Convert to 64 bit and add to sum.
- const __m128i dist1 = _mm_unpacklo_epi32(v_zero, dist[0]);
- const __m128i dist2 = _mm_unpackhi_epi32(v_zero, dist[0]);
+ const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
+ const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
sum = _mm_add_epi64(sum, dist1);
sum = _mm_add_epi64(sum, dist2);
}
diff --git a/test/av1_k_means_test.cc b/test/av1_k_means_test.cc
index 2ef837c..5b6c22e 100644
--- a/test/av1_k_means_test.cc
+++ b/test/av1_k_means_test.cc
@@ -94,10 +94,11 @@
const int w = block_size_wide[bsize];
const int h = block_size_high[bsize];
const int n = w * h;
- av1_calc_indices_dim1_c(data_, centroids_, indices1_, /*total_dist=*/nullptr,
- n, k);
- test_impl(data_, centroids_, indices2_, /*total_dist=*/nullptr, n, k);
+ int64_t total_dist_dim1, total_dist_impl;
+ av1_calc_indices_dim1_c(data_, centroids_, indices1_, &total_dist_dim1, n, k);
+ test_impl(data_, centroids_, indices2_, &total_dist_impl, n, k);
+ ASSERT_EQ(total_dist_dim1, total_dist_impl);
ASSERT_EQ(CheckResult(n), true)
<< " block " << bsize << " index " << n << " Centroids " << k;
}
@@ -203,10 +204,11 @@
const int w = block_size_wide[bsize];
const int h = block_size_high[bsize];
const int n = w * h;
- av1_calc_indices_dim2_c(data_, centroids_, indices1_, /*total_dist=*/nullptr,
- n, k);
- test_impl(data_, centroids_, indices2_, /*total_dist=*/nullptr, n, k);
+ int64_t total_dist_dim2, total_dist_impl;
+ av1_calc_indices_dim2_c(data_, centroids_, indices1_, &total_dist_dim2, n, k);
+ test_impl(data_, centroids_, indices2_, &total_dist_impl, n, k);
+ ASSERT_EQ(total_dist_dim2, total_dist_impl);
ASSERT_EQ(CheckResult(n), true)
<< " block " << bsize << " index " << n << " Centroids " << k;
}
@@ -225,7 +227,7 @@
aom_usec_timer_start(&timer);
av1_calc_indices_dim2_func func = funcs[i];
for (int j = 0; j < num_loops; ++j) {
- func(data_, centroids_, indices1_, nullptr, n, k);
+ func(data_, centroids_, indices1_, /*total_dist=*/nullptr, n, k);
}
aom_usec_timer_mark(&timer);
double time = static_cast<double>(aom_usec_timer_elapsed(&timer));