K-means: switch pipeline to 16 bit.
Less RAM, less for loops in SIMD, simpler SIMD instructions.
On rtc_screen, speed 0, 50 runs, speed-ups are in %:
screen_recording_crd.1920_1080.y4m 0.483
screenshare_buganizer.1900_1306.y4m 0.718
screenshare_colorslides.1820_1320.y4m 0.913
screenshare_slidechanges.1850_1110.y4m 0.576
screenshare_youtube.1680_1178.y4m 0.465
slides_webplot.1920_1080.y4m 0.475
sc_web_browsing720p.y4m 0.690
screen_crd_colwinscroll.1920_1128.y4m 0.146
{OVERALL} 0.558
At speed 10:
screen_recording_crd.1920_1080.y4m 0.347
screenshare_buganizer.1900_1306.y4m 1.145
screenshare_colorslides.1820_1320.y4m 1.028
screenshare_slidechanges.1850_1110.y4m 0.945
screenshare_youtube.1680_1178.y4m 0.366
slides_webplot.1920_1080.y4m 0.999
sc_web_browsing720p.y4m 1.261
screen_crd_colwinscroll.1920_1128.y4m 0.133
{OVERALL} 0.778
Change-Id: I009d31e749c54474b56ff36aec520f686d728f9e
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index b5baaae..ba1dcbb 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -412,10 +412,10 @@
add_proto qw/void av1_quantize_b/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan, const qm_val_t * qm_ptr, const qm_val_t * iqm_ptr, int log_scale";
- add_proto qw/void av1_calc_indices_dim1/, "const int *data, const int *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
+ add_proto qw/void av1_calc_indices_dim1/, "const int16_t *data, const int16_t *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
specialize qw/av1_calc_indices_dim1 sse2 avx2/;
- add_proto qw/void av1_calc_indices_dim2/, "const int *data, const int *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
+ add_proto qw/void av1_calc_indices_dim2/, "const int16_t *data, const int16_t *centroids, uint8_t *indices, int64_t *total_dist, int n, int k";
specialize qw/av1_calc_indices_dim2 sse2 avx2/;
# ENCODEMB INVOKE
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 3dec881..4185798 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -326,7 +326,7 @@
//! The best color map found.
uint8_t best_palette_color_map[MAX_PALETTE_SQUARE];
//! A temporary buffer used for k-means clustering.
- int kmeans_data_buf[2 * MAX_PALETTE_SQUARE];
+ int16_t kmeans_data_buf[2 * MAX_PALETTE_SQUARE];
} PALETTE_BUFFER;
/*! \brief Contains buffers used by av1_compound_type_rd()
diff --git a/av1/encoder/k_means_template.h b/av1/encoder/k_means_template.h
index 0cf72b7..31ffdcf 100644
--- a/av1/encoder/k_means_template.h
+++ b/av1/encoder/k_means_template.h
@@ -28,7 +28,7 @@
// Though we want to compute the smallest L2 norm, in 1 dimension,
// it is equivalent to find the smallest L1 norm and then square it.
// This is preferrable for speed, especially on the SIMD side.
-static int RENAME(calc_dist)(const int *p1, const int *p2) {
+static int RENAME(calc_dist)(const int16_t *p1, const int16_t *p2) {
#if AV1_K_MEANS_DIM == 1
return abs(p1[0] - p2[0]);
#else
@@ -41,7 +41,7 @@
#endif
}
-void RENAME(av1_calc_indices)(const int *data, const int *centroids,
+void RENAME(av1_calc_indices)(const int16_t *data, const int16_t *centroids,
uint8_t *indices, int64_t *dist, int n, int k) {
if (dist) {
*dist = 0;
@@ -67,20 +67,22 @@
}
}
-static void RENAME(calc_centroids)(const int *data, int *centroids,
+static void RENAME(calc_centroids)(const int16_t *data, int16_t *centroids,
const uint8_t *indices, int n, int k) {
int i, j;
int count[PALETTE_MAX_SIZE] = { 0 };
+ int centroids_sum[AV1_K_MEANS_DIM * PALETTE_MAX_SIZE];
unsigned int rand_state = (unsigned int)data[0];
assert(n <= 32768);
- memset(centroids, 0, sizeof(centroids[0]) * k * AV1_K_MEANS_DIM);
+ memset(centroids_sum, 0, sizeof(centroids_sum[0]) * k * AV1_K_MEANS_DIM);
for (i = 0; i < n; ++i) {
const int index = indices[i];
assert(index < k);
++count[index];
for (j = 0; j < AV1_K_MEANS_DIM; ++j) {
- centroids[index * AV1_K_MEANS_DIM + j] += data[i * AV1_K_MEANS_DIM + j];
+ centroids_sum[index * AV1_K_MEANS_DIM + j] +=
+ data[i * AV1_K_MEANS_DIM + j];
}
}
@@ -92,17 +94,17 @@
} else {
for (j = 0; j < AV1_K_MEANS_DIM; ++j) {
centroids[i * AV1_K_MEANS_DIM + j] =
- DIVIDE_AND_ROUND(centroids[i * AV1_K_MEANS_DIM + j], count[i]);
+ DIVIDE_AND_ROUND(centroids_sum[i * AV1_K_MEANS_DIM + j], count[i]);
}
}
}
}
-void RENAME(av1_k_means)(const int *data, int *centroids, uint8_t *indices,
- int n, int k, int max_itr) {
- int centroids_tmp[AV1_K_MEANS_DIM * PALETTE_MAX_SIZE];
+void RENAME(av1_k_means)(const int16_t *data, int16_t *centroids,
+ uint8_t *indices, int n, int k, int max_itr) {
+ int16_t centroids_tmp[AV1_K_MEANS_DIM * PALETTE_MAX_SIZE];
uint8_t indices_tmp[MAX_PALETTE_BLOCK_WIDTH * MAX_PALETTE_BLOCK_HEIGHT];
- int *meta_centroids[2] = { centroids, centroids_tmp };
+ int16_t *meta_centroids[2] = { centroids, centroids_tmp };
uint8_t *meta_indices[2] = { indices, indices_tmp };
int i, l = 0, prev_l, best_l = 0;
int64_t this_dist;
diff --git a/av1/encoder/palette.c b/av1/encoder/palette.c
index 4375175..9c3d407 100644
--- a/av1/encoder/palette.c
+++ b/av1/encoder/palette.c
@@ -31,14 +31,14 @@
#include "av1/encoder/k_means_template.h"
#undef AV1_K_MEANS_DIM
-static int int_comparer(const void *a, const void *b) {
- return (*(int *)a - *(int *)b);
+static int int16_comparer(const void *a, const void *b) {
+ return (*(int16_t *)a - *(int16_t *)b);
}
-int av1_remove_duplicates(int *centroids, int num_centroids) {
+int av1_remove_duplicates(int16_t *centroids, int num_centroids) {
int num_unique; // number of unique centroids
int i;
- qsort(centroids, num_centroids, sizeof(*centroids), int_comparer);
+ qsort(centroids, num_centroids, sizeof(*centroids), int16_comparer);
// Remove duplicates.
num_unique = 1;
for (i = 1; i < num_centroids; ++i) {
@@ -189,14 +189,14 @@
// TODO(huisu): Try other schemes to improve compression.
static AOM_INLINE void optimize_palette_colors(uint16_t *color_cache,
int n_cache, int n_colors,
- int stride, int *centroids,
+ int stride, int16_t *centroids,
int bit_depth) {
if (n_cache <= 0) return;
for (int i = 0; i < n_colors * stride; i += stride) {
- int min_diff = abs(centroids[i] - (int)color_cache[0]);
+ int min_diff = abs((int)centroids[i] - (int)color_cache[0]);
int idx = 0;
for (int j = 1; j < n_cache; ++j) {
- const int this_diff = abs(centroids[i] - color_cache[j]);
+ const int this_diff = abs((int)centroids[i] - (int)color_cache[j]);
if (this_diff < min_diff) {
min_diff = this_diff;
idx = j;
@@ -216,8 +216,8 @@
*/
static AOM_INLINE void palette_rd_y(
const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
- BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int *centroids, int n,
- uint16_t *color_cache, int n_cache, bool do_header_rd_based_gating,
+ BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data, int16_t *centroids,
+ int n, uint16_t *color_cache, int n_cache, bool do_header_rd_based_gating,
MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
int *rate, int *rate_tokenonly, int64_t *distortion, uint8_t *skippable,
int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *blk_skip,
@@ -324,14 +324,14 @@
// returns the best number of colors found.
static AOM_INLINE int perform_top_color_palette_search(
const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
- BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int *top_colors,
- int start_n, int end_n, int step_size, bool do_header_rd_based_gating,
- int *last_n_searched, uint16_t *color_cache, int n_cache,
- MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
- int *rate, int *rate_tokenonly, int64_t *distortion, uint8_t *skippable,
- int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip,
- uint8_t *tx_type_map, int discount_color_cost) {
- int centroids[PALETTE_MAX_SIZE];
+ BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data,
+ int16_t *top_colors, int start_n, int end_n, int step_size,
+ bool do_header_rd_based_gating, int *last_n_searched, uint16_t *color_cache,
+ int n_cache, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
+ int64_t *best_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
+ uint8_t *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
+ uint8_t *best_blk_skip, uint8_t *tx_type_map, int discount_color_cost) {
+ int16_t centroids[PALETTE_MAX_SIZE];
int n = start_n;
int top_color_winner = end_n;
/* clang-format off */
@@ -371,7 +371,7 @@
// returns the best number of colors found.
static AOM_INLINE int perform_k_means_palette_search(
const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
- BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int lower_bound,
+ BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data, int lower_bound,
int upper_bound, int start_n, int end_n, int step_size,
bool do_header_rd_based_gating, int *last_n_searched, uint16_t *color_cache,
int n_cache, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
@@ -379,7 +379,7 @@
uint8_t *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
uint8_t *best_blk_skip, uint8_t *tx_type_map, uint8_t *color_map,
int data_points, int discount_color_cost) {
- int centroids[PALETTE_MAX_SIZE];
+ int16_t centroids[PALETTE_MAX_SIZE];
const int max_itr = 50;
int n = start_n;
int top_color_winner = end_n;
@@ -435,16 +435,19 @@
*step_size = AOMMAX(1, *max_n - *min_n);
}
-static AOM_INLINE void fill_data_and_get_bounds(
- const uint8_t *src, const int src_stride, const int rows, const int cols,
- const int is_high_bitdepth, int *data, int *lower_bound, int *upper_bound) {
+static AOM_INLINE void fill_data_and_get_bounds(const uint8_t *src,
+ const int src_stride,
+ const int rows, const int cols,
+ const int is_high_bitdepth,
+ int16_t *data, int *lower_bound,
+ int *upper_bound) {
if (is_high_bitdepth) {
const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
*lower_bound = *upper_bound = src_ptr[0];
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
const int val = src_ptr[c];
- data[c] = val;
+ data[c] = (int16_t)val;
*lower_bound = AOMMIN(*lower_bound, val);
*upper_bound = AOMMAX(*upper_bound, val);
}
@@ -459,7 +462,7 @@
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
const int val = src[c];
- data[c] = val;
+ data[c] = (int16_t)val;
*lower_bound = AOMMIN(*lower_bound, val);
*upper_bound = AOMMAX(*upper_bound, val);
}
@@ -487,7 +490,7 @@
}
static void find_top_colors(const int *const count_buf, int bit_depth,
- int n_colors, int *top_colors) {
+ int n_colors, int16_t *top_colors) {
// Top color array, serving as a priority queue if more than n_colors are
// found.
struct ColorCount top_color_counts[PALETTE_MAX_SIZE] = { { 0 } };
@@ -562,8 +565,8 @@
uint8_t *const color_map = xd->plane[0].color_index_map;
if (colors_threshold > 1 && colors_threshold <= 64) {
- int *const data = x->palette_buffer->kmeans_data_buf;
- int centroids[PALETTE_MAX_SIZE];
+ int16_t *const data = x->palette_buffer->kmeans_data_buf;
+ int16_t centroids[PALETTE_MAX_SIZE];
int lower_bound, upper_bound;
fill_data_and_get_bounds(src, src_stride, rows, cols, is_hbd, data,
&lower_bound, &upper_bound);
@@ -575,7 +578,7 @@
const int n_cache = av1_get_palette_cache(xd, 0, color_cache);
// Find the dominant colors, stored in top_colors[].
- int top_colors[PALETTE_MAX_SIZE] = { 0 };
+ int16_t top_colors[PALETTE_MAX_SIZE] = { 0 };
find_top_colors(count_buf, bit_depth, AOMMIN(colors, PALETTE_MAX_SIZE),
top_colors);
@@ -791,8 +794,8 @@
const int max_itr = 50;
int lb_u, ub_u, val_u;
int lb_v, ub_v, val_v;
- int *const data = x->palette_buffer->kmeans_data_buf;
- int centroids[2 * PALETTE_MAX_SIZE];
+ int16_t *const data = x->palette_buffer->kmeans_data_buf;
+ int16_t centroids[2 * PALETTE_MAX_SIZE];
uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src_u);
uint16_t *src_v16 = CONVERT_TO_SHORTPTR(src_v);
@@ -917,8 +920,8 @@
int src_stride = x->plane[1].src.stride;
const uint8_t *const src_u = x->plane[1].src.buf;
const uint8_t *const src_v = x->plane[2].src.buf;
- int *const data = x->palette_buffer->kmeans_data_buf;
- int centroids[2 * PALETTE_MAX_SIZE];
+ int16_t *const data = x->palette_buffer->kmeans_data_buf;
+ int16_t centroids[2 * PALETTE_MAX_SIZE];
uint8_t *const color_map = xd->plane[1].color_index_map;
int r, c;
const uint16_t *const src_u16 = CONVERT_TO_SHORTPTR(src_u);
diff --git a/av1/encoder/palette.h b/av1/encoder/palette.h
index 6f33f44..7da863a 100644
--- a/av1/encoder/palette.h
+++ b/av1/encoder/palette.h
@@ -28,10 +28,10 @@
/*!\cond */
#define AV1_K_MEANS_RENAME(func, dim) func##_dim##dim##_c
-void AV1_K_MEANS_RENAME(av1_k_means, 1)(const int *data, int *centroids,
+void AV1_K_MEANS_RENAME(av1_k_means, 1)(const int16_t *data, int16_t *centroids,
uint8_t *indices, int n, int k,
int max_itr);
-void AV1_K_MEANS_RENAME(av1_k_means, 2)(const int *data, int *centroids,
+void AV1_K_MEANS_RENAME(av1_k_means, 2)(const int16_t *data, int16_t *centroids,
uint8_t *indices, int n, int k,
int max_itr);
/*!\endcond */
@@ -51,8 +51,9 @@
*
* \remark Returns nothing, but saves each data's cluster index in \a indices.
*/
-static INLINE void av1_calc_indices(const int *data, const int *centroids,
- uint8_t *indices, int n, int k, int dim) {
+static INLINE void av1_calc_indices(const int16_t *data,
+ const int16_t *centroids, uint8_t *indices,
+ int n, int k, int dim) {
assert(n > 0);
assert(k > 0);
if (dim == 1) {
@@ -84,7 +85,7 @@
*
* \attention The output centroids are rounded off to nearest integers.
*/
-static INLINE void av1_k_means(const int *data, int *centroids,
+static INLINE void av1_k_means(const int16_t *data, int16_t *centroids,
uint8_t *indices, int n, int k, int dim,
int max_itr) {
assert(n > 0);
@@ -110,7 +111,7 @@
* \attention The centroids should be rounded to integers before calling this
* method.
*/
-int av1_remove_duplicates(int *centroids, int num_centroids);
+int av1_remove_duplicates(int16_t *centroids, int num_centroids);
/*!\brief Checks what colors are in the color cache.
*
diff --git a/av1/encoder/x86/av1_k_means_avx2.c b/av1/encoder/x86/av1_k_means_avx2.c
index 854b90c..a2db222 100644
--- a/av1/encoder/x86/av1_k_means_avx2.c
+++ b/av1/encoder/x86/av1_k_means_avx2.c
@@ -23,7 +23,55 @@
return res;
}
-void av1_calc_indices_dim1_avx2(const int *data, const int *centroids,
+void av1_calc_indices_dim1_avx2(const int16_t *data, const int16_t *centroids,
+ uint8_t *indices, int64_t *total_dist, int n,
+ int k) {
+ __m256i dist[PALETTE_MAX_SIZE];
+ const __m256i v_zero = _mm256_setzero_si256();
+ __m256i sum = _mm256_setzero_si256();
+
+ for (int i = 0; i < n; i += 16) {
+ __m256i ind = _mm256_loadu_si256((__m256i *)data);
+ for (int j = 0; j < k; j++) {
+ __m256i cent = _mm256_set1_epi16(centroids[j]);
+ __m256i d1 = _mm256_sub_epi16(ind, cent);
+ dist[j] = _mm256_abs_epi16(d1);
+ }
+
+ ind = _mm256_setzero_si256();
+ for (int j = 1; j < k; j++) {
+ __m256i cmp = _mm256_cmpgt_epi16(dist[0], dist[j]);
+ dist[0] = _mm256_min_epi16(dist[0], dist[j]);
+ __m256i ind1 = _mm256_set1_epi16(j);
+ ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind),
+ _mm256_and_si256(cmp, ind1));
+ }
+
+ __m256i p1 = _mm256_packus_epi16(ind, v_zero);
+ __m256i px = _mm256_permute4x64_epi64(p1, 0x58);
+ __m128i d1 = _mm256_extracti128_si256(px, 0);
+
+ _mm_storeu_si128((__m128i *)indices, d1);
+
+ if (total_dist) {
+ // Square, convert to 32 bit and add together.
+ dist[0] = _mm256_madd_epi16(dist[0], dist[0]);
+ // Convert to 64 bit and add to sum.
+ const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
+ const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
+ sum = _mm256_add_epi64(sum, dist1);
+ sum = _mm256_add_epi64(sum, dist2);
+ }
+
+ indices += 16;
+ data += 16;
+ }
+ if (total_dist) {
+ *total_dist = k_means_horizontal_sum_avx2(sum);
+ }
+}
+
+void av1_calc_indices_dim2_avx2(const int16_t *data, const int16_t *centroids,
uint8_t *indices, int64_t *total_dist, int n,
int k) {
__m256i dist[PALETTE_MAX_SIZE];
@@ -33,77 +81,18 @@
for (int i = 0; i < n; i += 8) {
__m256i ind = _mm256_loadu_si256((__m256i *)data);
for (int j = 0; j < k; j++) {
- __m256i cent = _mm256_set1_epi32(centroids[j]);
- __m256i d1 = _mm256_sub_epi32(ind, cent);
- dist[j] = _mm256_abs_epi32(d1);
+ const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
+ const __m256i cent = _mm256_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx, cy,
+ cx, cy, cx, cy, cx, cy, cx);
+ const __m256i d1 = _mm256_sub_epi16(ind, cent);
+ dist[j] = _mm256_madd_epi16(d1, d1);
}
ind = _mm256_setzero_si256();
for (int j = 1; j < k; j++) {
__m256i cmp = _mm256_cmpgt_epi32(dist[0], dist[j]);
- __m256i dist1 = _mm256_andnot_si256(cmp, dist[0]);
- __m256i dist2 = _mm256_and_si256(cmp, dist[j]);
- dist[0] = _mm256_or_si256(dist1, dist2);
- __m256i ind1 = _mm256_set1_epi32(j);
- ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind),
- _mm256_and_si256(cmp, ind1));
- }
-
- __m256i p1 = _mm256_packus_epi32(ind, v_zero);
- __m256i px = _mm256_permute4x64_epi64(p1, 0x58);
- __m256i p2 = _mm256_packus_epi16(px, v_zero);
- __m128i d1 = _mm256_extracti128_si256(p2, 0);
-
- _mm_storel_epi64((__m128i *)indices, d1);
-
- if (total_dist) {
- // Square, convert to 64 bit and add to sum.
- dist[0] = _mm256_mullo_epi32(dist[0], dist[0]);
- const __m256i dist1 = _mm256_unpacklo_epi32(dist[0], v_zero);
- const __m256i dist2 = _mm256_unpackhi_epi32(dist[0], v_zero);
- sum = _mm256_add_epi64(sum, dist1);
- sum = _mm256_add_epi64(sum, dist2);
- }
-
- indices += 8;
- data += 8;
- }
- if (total_dist) {
- *total_dist = k_means_horizontal_sum_avx2(sum);
- }
-}
-
-void av1_calc_indices_dim2_avx2(const int *data, const int *centroids,
- uint8_t *indices, int64_t *total_dist, int n,
- int k) {
- __m256i dist[PALETTE_MAX_SIZE];
- const __m256i v_zero = _mm256_setzero_si256();
- const __m256i v_permute = _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7);
- __m256i sum = _mm256_setzero_si256();
-
- for (int i = 0; i < n; i += 8) {
- __m256i ind1 = _mm256_loadu_si256((__m256i *)data);
- __m256i ind2 = _mm256_loadu_si256((__m256i *)(data + 8));
- for (int j = 0; j < k; j++) {
- __m128i cent0 = _mm_loadl_epi64((__m128i const *)¢roids[2 * j]);
- __m256i cent1 = _mm256_inserti128_si256(v_zero, cent0, 0);
- cent1 = _mm256_inserti128_si256(cent1, cent0, 1);
- __m256i cent = _mm256_unpacklo_epi64(cent1, cent1);
- __m256i d1 = _mm256_sub_epi32(ind1, cent);
- __m256i d2 = _mm256_sub_epi32(ind2, cent);
- __m256i d3 = _mm256_mullo_epi32(d1, d1);
- __m256i d4 = _mm256_mullo_epi32(d2, d2);
- __m256i d5 = _mm256_hadd_epi32(d3, d4);
- dist[j] = _mm256_permutevar8x32_epi32(d5, v_permute);
- }
-
- __m256i ind = _mm256_setzero_si256();
- for (int j = 1; j < k; j++) {
- __m256i cmp = _mm256_cmpgt_epi32(dist[0], dist[j]);
- __m256i dist1 = _mm256_andnot_si256(cmp, dist[0]);
- __m256i dist2 = _mm256_and_si256(cmp, dist[j]);
- dist[0] = _mm256_or_si256(dist1, dist2);
- ind1 = _mm256_set1_epi32(j);
+ dist[0] = _mm256_min_epi32(dist[0], dist[j]);
+ const __m256i ind1 = _mm256_set1_epi32(j);
ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind),
_mm256_and_si256(cmp, ind1));
}
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c
index d2c7796..a284fa9 100644
--- a/av1/encoder/x86/av1_k_means_sse2.c
+++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -22,68 +22,49 @@
return res;
}
-static __m128i absolute_diff_epi32(__m128i a, __m128i b) {
- const __m128i diff1 = _mm_sub_epi32(a, b);
- const __m128i diff2 = _mm_sub_epi32(b, a);
- const __m128i cmp = _mm_cmpgt_epi32(diff1, diff2);
- const __m128i masked1 = _mm_and_si128(cmp, diff1);
- const __m128i masked2 = _mm_andnot_si128(cmp, diff2);
- return _mm_or_si128(masked1, masked2);
-}
-
-void av1_calc_indices_dim1_sse2(const int *data, const int *centroids,
+void av1_calc_indices_dim1_sse2(const int16_t *data, const int16_t *centroids,
uint8_t *indices, int64_t *total_dist, int n,
int k) {
const __m128i v_zero = _mm_setzero_si128();
- int l = 1;
__m128i dist[PALETTE_MAX_SIZE];
- __m128i ind[2];
__m128i sum = _mm_setzero_si128();
- for (int i = 0; i < n; i += 4) {
- l = (l == 0) ? 1 : 0;
- ind[l] = _mm_loadu_si128((__m128i *)data);
+ for (int i = 0; i < n; i += 8) {
+ __m128i in = _mm_loadu_si128((__m128i *)data);
for (int j = 0; j < k; j++) {
- __m128i cent = _mm_set1_epi32(centroids[j]);
- dist[j] = absolute_diff_epi32(ind[l], cent);
+ __m128i cent = _mm_set1_epi16(centroids[j]);
+ __m128i d1 = _mm_sub_epi16(in, cent);
+ __m128i d2 = _mm_sub_epi16(cent, in);
+ dist[j] = _mm_max_epi16(d1, d2);
}
- ind[l] = _mm_setzero_si128();
+ __m128i ind = _mm_setzero_si128();
for (int j = 1; j < k; j++) {
- __m128i cmp = _mm_cmpgt_epi32(dist[0], dist[j]);
- __m128i dist1 = _mm_andnot_si128(cmp, dist[0]);
- __m128i dist2 = _mm_and_si128(cmp, dist[j]);
- dist[0] = _mm_or_si128(dist1, dist2);
- __m128i ind1 = _mm_set1_epi32(j);
- ind[l] =
- _mm_or_si128(_mm_andnot_si128(cmp, ind[l]), _mm_and_si128(cmp, ind1));
+ __m128i cmp = _mm_cmpgt_epi16(dist[0], dist[j]);
+ dist[0] = _mm_min_epi16(dist[0], dist[j]);
+ __m128i ind1 = _mm_set1_epi16(j);
+ ind = _mm_or_si128(_mm_andnot_si128(cmp, ind), _mm_and_si128(cmp, ind1));
}
- ind[l] = _mm_packus_epi16(ind[l], v_zero);
if (total_dist) {
- // Square and convert to 32 bit.
- const __m128i d1 = _mm_packs_epi32(dist[0], v_zero);
- const __m128i d2 = _mm_mullo_epi16(d1, d1);
- const __m128i d3 = _mm_mulhi_epi16(d1, d1);
- dist[0] = _mm_unpacklo_epi16(d2, d3);
+ // Square, convert to 32 bit and add together.
+ dist[0] = _mm_madd_epi16(dist[0], dist[0]);
// Convert to 64 bit and add to sum.
const __m128i dist1 = _mm_unpacklo_epi32(dist[0], v_zero);
const __m128i dist2 = _mm_unpackhi_epi32(dist[0], v_zero);
sum = _mm_add_epi64(sum, dist1);
sum = _mm_add_epi64(sum, dist2);
}
- if (l == 1) {
- __m128i p2 = _mm_packus_epi16(_mm_unpacklo_epi64(ind[0], ind[1]), v_zero);
- _mm_storel_epi64((__m128i *)indices, p2);
- indices += 8;
- }
- data += 4;
+ __m128i p2 = _mm_packus_epi16(ind, v_zero);
+ _mm_storel_epi64((__m128i *)indices, p2);
+ indices += 8;
+ data += 8;
}
if (total_dist) {
*total_dist = k_means_horizontal_sum_sse2(sum);
}
}
-void av1_calc_indices_dim2_sse2(const int *data, const int *centroids,
+void av1_calc_indices_dim2_sse2(const int16_t *data, const int16_t *centroids,
uint8_t *indices, int64_t *total_dist, int n,
int k) {
const __m128i v_zero = _mm_setzero_si128();
@@ -95,19 +76,11 @@
for (int i = 0; i < n; i += 4) {
l = (l == 0) ? 1 : 0;
__m128i ind1 = _mm_loadu_si128((__m128i *)data);
- __m128i ind2 = _mm_loadu_si128((__m128i *)(data + 4));
- __m128i indl = _mm_unpacklo_epi32(ind1, ind2);
- __m128i indh = _mm_unpackhi_epi32(ind1, ind2);
- ind1 = _mm_unpacklo_epi32(indl, indh);
- ind2 = _mm_unpackhi_epi32(indl, indh);
for (int j = 0; j < k; j++) {
- __m128i cent0 = _mm_set1_epi32(centroids[2 * j]);
- __m128i cent1 = _mm_set1_epi32(centroids[2 * j + 1]);
- __m128i d1 = absolute_diff_epi32(ind1, cent0);
- __m128i d2 = absolute_diff_epi32(ind2, cent1);
- __m128i d3 = _mm_madd_epi16(d1, d1);
- __m128i d4 = _mm_madd_epi16(d2, d2);
- dist[j] = _mm_add_epi32(d3, d4);
+ const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
+ const __m128i cent = _mm_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx);
+ const __m128i d1 = _mm_sub_epi16(ind1, cent);
+ dist[j] = _mm_madd_epi16(d1, d1);
}
ind[l] = _mm_setzero_si128();
diff --git a/test/av1_k_means_test.cc b/test/av1_k_means_test.cc
index 5b6c22e..221dd10 100644
--- a/test/av1_k_means_test.cc
+++ b/test/av1_k_means_test.cc
@@ -28,12 +28,12 @@
#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
namespace AV1Kmeans {
-typedef void (*av1_calc_indices_dim1_func)(const int *data,
- const int *centroids,
+typedef void (*av1_calc_indices_dim1_func)(const int16_t *data,
+ const int16_t *centroids,
uint8_t *indices,
int64_t *total_dist, int n, int k);
-typedef void (*av1_calc_indices_dim2_func)(const int *data,
- const int *centroids,
+typedef void (*av1_calc_indices_dim2_func)(const int16_t *data,
+ const int16_t *centroids,
uint8_t *indices,
int64_t *total_dist, int n, int k);
@@ -68,8 +68,8 @@
}
libaom_test::ACMRandom rnd_;
- int data_[4096];
- int centroids_[8];
+ int16_t data_[4096];
+ int16_t centroids_[8];
uint8_t indices1_[4096];
uint8_t indices2_[4096];
};
@@ -178,8 +178,8 @@
}
libaom_test::ACMRandom rnd_;
- int data_[4096 * 2];
- int centroids_[8 * 2];
+ int16_t data_[4096 * 2];
+ int16_t centroids_[8 * 2];
uint8_t indices1_[4096];
uint8_t indices2_[4096];
};