David Barker | ee67432 | 2017-05-10 15:43:02 +0100 | [diff] [blame] | 1 | #include <stdlib.h> |
| 2 | #include <memory.h> |
| 3 | #include <math.h> |
| 4 | #include <assert.h> |
| 5 | |
| 6 | #include <smmintrin.h> |
| 7 | |
| 8 | #include "./av1_rtcd.h" |
| 9 | #include "aom_ports/mem.h" |
| 10 | #include "av1/encoder/corner_match.h" |
| 11 | |
| 12 | DECLARE_ALIGNED(16, static const uint8_t, byte_mask[16]) = { |
| 13 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0 |
| 14 | }; |
| 15 | #if MATCH_SZ != 13 |
| 16 | #error "Need to change byte_mask in corner_match_sse4.c if MATCH_SZ != 13" |
| 17 | #endif |
| 18 | |
| 19 | /* Compute corr(im1, im2) * MATCH_SZ * stddev(im1), where the |
| 20 | correlation/standard deviation are taken over MATCH_SZ by MATCH_SZ windows |
| 21 | of each image, centered at (x1, y1) and (x2, y2) respectively. |
| 22 | */ |
| 23 | double compute_cross_correlation_sse4_1(unsigned char *im1, int stride1, int x1, |
| 24 | int y1, unsigned char *im2, int stride2, |
| 25 | int x2, int y2) { |
| 26 | int i; |
| 27 | // 2 16-bit partial sums in lanes 0, 4 (== 2 32-bit partial sums in lanes 0, |
| 28 | // 2) |
| 29 | __m128i sum1_vec = _mm_setzero_si128(); |
| 30 | __m128i sum2_vec = _mm_setzero_si128(); |
| 31 | // 4 32-bit partial sums of squares |
| 32 | __m128i sumsq2_vec = _mm_setzero_si128(); |
| 33 | __m128i cross_vec = _mm_setzero_si128(); |
| 34 | |
| 35 | const __m128i mask = _mm_load_si128((__m128i *)byte_mask); |
| 36 | const __m128i zero = _mm_setzero_si128(); |
| 37 | |
| 38 | im1 += (y1 - MATCH_SZ_BY2) * stride1 + (x1 - MATCH_SZ_BY2); |
| 39 | im2 += (y2 - MATCH_SZ_BY2) * stride2 + (x2 - MATCH_SZ_BY2); |
| 40 | |
| 41 | for (i = 0; i < MATCH_SZ; ++i) { |
| 42 | const __m128i v1 = |
| 43 | _mm_and_si128(_mm_loadu_si128((__m128i *)&im1[i * stride1]), mask); |
| 44 | const __m128i v2 = |
| 45 | _mm_and_si128(_mm_loadu_si128((__m128i *)&im2[i * stride2]), mask); |
| 46 | |
| 47 | // Using the 'sad' intrinsic here is a bit faster than adding |
| 48 | // v1_l + v1_r and v2_l + v2_r, plus it avoids the need for a 16->32 bit |
| 49 | // conversion step later, for a net speedup of ~10% |
| 50 | sum1_vec = _mm_add_epi16(sum1_vec, _mm_sad_epu8(v1, zero)); |
| 51 | sum2_vec = _mm_add_epi16(sum2_vec, _mm_sad_epu8(v2, zero)); |
| 52 | |
| 53 | const __m128i v1_l = _mm_cvtepu8_epi16(v1); |
| 54 | const __m128i v1_r = _mm_cvtepu8_epi16(_mm_srli_si128(v1, 8)); |
| 55 | const __m128i v2_l = _mm_cvtepu8_epi16(v2); |
| 56 | const __m128i v2_r = _mm_cvtepu8_epi16(_mm_srli_si128(v2, 8)); |
| 57 | |
| 58 | sumsq2_vec = _mm_add_epi32( |
| 59 | sumsq2_vec, |
| 60 | _mm_add_epi32(_mm_madd_epi16(v2_l, v2_l), _mm_madd_epi16(v2_r, v2_r))); |
| 61 | cross_vec = _mm_add_epi32( |
| 62 | cross_vec, |
| 63 | _mm_add_epi32(_mm_madd_epi16(v1_l, v2_l), _mm_madd_epi16(v1_r, v2_r))); |
| 64 | } |
| 65 | |
| 66 | // Now we can treat the four registers (sum1_vec, sum2_vec, sumsq2_vec, |
| 67 | // cross_vec) |
| 68 | // as holding 4 32-bit elements each, which we want to sum horizontally. |
| 69 | // We do this by transposing and then summing vertically. |
| 70 | __m128i tmp_0 = _mm_unpacklo_epi32(sum1_vec, sum2_vec); |
| 71 | __m128i tmp_1 = _mm_unpackhi_epi32(sum1_vec, sum2_vec); |
| 72 | __m128i tmp_2 = _mm_unpacklo_epi32(sumsq2_vec, cross_vec); |
| 73 | __m128i tmp_3 = _mm_unpackhi_epi32(sumsq2_vec, cross_vec); |
| 74 | |
| 75 | __m128i tmp_4 = _mm_unpacklo_epi64(tmp_0, tmp_2); |
| 76 | __m128i tmp_5 = _mm_unpackhi_epi64(tmp_0, tmp_2); |
| 77 | __m128i tmp_6 = _mm_unpacklo_epi64(tmp_1, tmp_3); |
| 78 | __m128i tmp_7 = _mm_unpackhi_epi64(tmp_1, tmp_3); |
| 79 | |
| 80 | __m128i res = |
| 81 | _mm_add_epi32(_mm_add_epi32(tmp_4, tmp_5), _mm_add_epi32(tmp_6, tmp_7)); |
| 82 | |
| 83 | int sum1 = _mm_extract_epi32(res, 0); |
| 84 | int sum2 = _mm_extract_epi32(res, 1); |
| 85 | int sumsq2 = _mm_extract_epi32(res, 2); |
| 86 | int cross = _mm_extract_epi32(res, 3); |
| 87 | |
| 88 | int var2 = sumsq2 * MATCH_SZ_SQ - sum2 * sum2; |
| 89 | int cov = cross * MATCH_SZ_SQ - sum1 * sum2; |
| 90 | return cov / sqrt((double)var2); |
| 91 | } |