Fix aom_<highbd>_sse_sse4_1/avx2 issue 1. Add missing default branch in aom_sse_sse4_1/avx2, which covers cases where width not match any BLOCK_SIZE, such as 12, 24, etc. 2. Fix overflow issue for aom_highbd_sse_sse4_1 3. Add unit tests for aom_<highbd>_sse_sse4_1/avx2 Change-Id: I22a84a44bab585eb0f78afa4a1b7cdb44adc6ede
diff --git a/aom_dsp/x86/sse_avx2.c b/aom_dsp/x86/sse_avx2.c index 305dde5..fa45687 100644 --- a/aom_dsp/x86/sse_avx2.c +++ b/aom_dsp/x86/sse_avx2.c
@@ -48,6 +48,57 @@ return sum; } +static INLINE void summary_32_avx2(const __m256i *sum32, __m256i *sum) { + const __m256i sum0_4x64 = + _mm256_cvtepu32_epi64(_mm256_castsi256_si128(*sum32)); + const __m256i sum1_4x64 = + _mm256_cvtepu32_epi64(_mm256_extracti128_si256(*sum32, 1)); + const __m256i sum_4x64 = _mm256_add_epi64(sum0_4x64, sum1_4x64); + *sum = _mm256_add_epi64(*sum, sum_4x64); +} + +static INLINE int64_t summary_4x64_avx2(const __m256i sum_4x64) { + int64_t sum; + const __m128i sum_2x64 = _mm_add_epi64(_mm256_castsi256_si128(sum_4x64), + _mm256_extracti128_si256(sum_4x64, 1)); + const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8)); + + xx_storel_64(&sum, sum_1x64); + return sum; +} + +static INLINE void sse_w4x4_avx2(const uint8_t *a, int a_stride, + const uint8_t *b, int b_stride, __m256i *sum) { + const __m128i v_a0 = xx_loadl_32(a); + const __m128i v_a1 = xx_loadl_32(a + a_stride); + const __m128i v_a2 = xx_loadl_32(a + a_stride * 2); + const __m128i v_a3 = xx_loadl_32(a + a_stride * 3); + const __m128i v_b0 = xx_loadl_32(b); + const __m128i v_b1 = xx_loadl_32(b + b_stride); + const __m128i v_b2 = xx_loadl_32(b + b_stride * 2); + const __m128i v_b3 = xx_loadl_32(b + b_stride * 3); + const __m128i v_a0123 = _mm_unpacklo_epi64(_mm_unpacklo_epi32(v_a0, v_a1), + _mm_unpacklo_epi32(v_a2, v_a3)); + const __m128i v_b0123 = _mm_unpacklo_epi64(_mm_unpacklo_epi32(v_b0, v_b1), + _mm_unpacklo_epi32(v_b2, v_b3)); + const __m256i v_a_w = _mm256_cvtepu8_epi16(v_a0123); + const __m256i v_b_w = _mm256_cvtepu8_epi16(v_b0123); + const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w); + *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w)); +} + +static INLINE void sse_w8x2_avx2(const uint8_t *a, int a_stride, + const uint8_t *b, int b_stride, __m256i *sum) { + const __m128i v_a0 = xx_loadl_64(a); + const __m128i v_a1 = xx_loadl_64(a + a_stride); + const __m128i v_b0 = xx_loadl_64(b); + const __m128i v_b1 = xx_loadl_64(b + b_stride); + const __m256i v_a_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_a0, v_a1)); + const __m256i v_b_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_b0, v_b1)); + const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w); + *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w)); +} + int64_t aom_sse_avx2(const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height) { int32_t y = 0; @@ -56,22 +107,7 @@ switch (width) { case 4: do { - const __m128i v_a0 = xx_loadl_32(a); - const __m128i v_a1 = xx_loadl_32(a + a_stride); - const __m128i v_a2 = xx_loadl_32(a + a_stride * 2); - const __m128i v_a3 = xx_loadl_32(a + a_stride * 3); - const __m128i v_b0 = xx_loadl_32(b); - const __m128i v_b1 = xx_loadl_32(b + b_stride); - const __m128i v_b2 = xx_loadl_32(b + b_stride * 2); - const __m128i v_b3 = xx_loadl_32(b + b_stride * 3); - const __m128i v_a0123 = _mm_unpacklo_epi64( - _mm_unpacklo_epi32(v_a0, v_a1), _mm_unpacklo_epi32(v_a2, v_a3)); - const __m128i v_b0123 = _mm_unpacklo_epi64( - _mm_unpacklo_epi32(v_b0, v_b1), _mm_unpacklo_epi32(v_b2, v_b3)); - const __m256i v_a_w = _mm256_cvtepu8_epi16(v_a0123); - const __m256i v_b_w = _mm256_cvtepu8_epi16(v_b0123); - const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w); - sum = _mm256_add_epi32(sum, _mm256_madd_epi16(v_d_w, v_d_w)); + sse_w4x4_avx2(a, a_stride, b, b_stride, &sum); a += a_stride << 2; b += b_stride << 2; y += 4; @@ -80,16 +116,7 @@ break; case 8: do { - const __m128i v_a0 = xx_loadl_64(a); - const __m128i v_a1 = xx_loadl_64(a + a_stride); - const __m128i v_b0 = xx_loadl_64(b); - const __m128i v_b1 = xx_loadl_64(b + b_stride); - const __m256i v_a_w = - _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_a0, v_a1)); - const __m256i v_b_w = - _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_b0, v_b1)); - const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w); - sum = _mm256_add_epi32(sum, _mm256_madd_epi16(v_d_w, v_d_w)); + sse_w8x2_avx2(a, a_stride, b, b_stride, &sum); a += a_stride << 1; b += b_stride << 1; y += 2; @@ -141,7 +168,36 @@ } while (y < height); sse = summary_all_avx2(&sum); break; - default: break; + default: + if ((width & 0x07) == 0) { + do { + int i = 0; + do { + sse_w8x2_avx2(a + i, a_stride, b + i, b_stride, &sum); + i += 8; + } while (i < width); + a += a_stride << 1; + b += b_stride << 1; + y += 2; + } while (y < height); + } else { + do { + int i = 0; + do { + sse_w8x2_avx2(a + i, a_stride, b + i, b_stride, &sum); + const uint8_t *a2 = a + i + (a_stride << 1); + const uint8_t *b2 = b + i + (b_stride << 1); + sse_w8x2_avx2(a2, a_stride, b2, b_stride, &sum); + i += 8; + } while (i + 4 < width); + sse_w4x4_avx2(a + i, a_stride, b + i, b_stride, &sum); + a += a_stride << 2; + b += b_stride << 2; + y += 4; + } while (y < height); + } + sse = summary_all_avx2(&sum); + break; } return sse; @@ -155,6 +211,33 @@ *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w)); } +static INLINE void highbd_sse_w4x4_avx2(__m256i *sum, const uint16_t *a, + int a_stride, const uint16_t *b, + int b_stride) { + const __m128i v_a0 = xx_loadl_64(a); + const __m128i v_a1 = xx_loadl_64(a + a_stride); + const __m128i v_a2 = xx_loadl_64(a + a_stride * 2); + const __m128i v_a3 = xx_loadl_64(a + a_stride * 3); + const __m128i v_b0 = xx_loadl_64(b); + const __m128i v_b1 = xx_loadl_64(b + b_stride); + const __m128i v_b2 = xx_loadl_64(b + b_stride * 2); + const __m128i v_b3 = xx_loadl_64(b + b_stride * 3); + const __m256i v_a_w = yy_set_m128i(_mm_unpacklo_epi64(v_a0, v_a1), + _mm_unpacklo_epi64(v_a2, v_a3)); + const __m256i v_b_w = yy_set_m128i(_mm_unpacklo_epi64(v_b0, v_b1), + _mm_unpacklo_epi64(v_b2, v_b3)); + const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w); + *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w)); +} + +static INLINE void highbd_sse_w8x2_avx2(__m256i *sum, const uint16_t *a, + int a_stride, const uint16_t *b, + int b_stride) { + const __m256i v_a_w = yy_loadu2_128(a + a_stride, a); + const __m256i v_b_w = yy_loadu2_128(b + b_stride, b); + const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w); + *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w)); +} int64_t aom_highbd_sse_avx2(const uint8_t *a8, int a_stride, const uint8_t *b8, int b_stride, int width, int height) { int32_t y = 0; @@ -165,20 +248,7 @@ switch (width) { case 4: do { - const __m128i v_a0 = xx_loadl_64(a); - const __m128i v_a1 = xx_loadl_64(a + a_stride); - const __m128i v_a2 = xx_loadl_64(a + a_stride * 2); - const __m128i v_a3 = xx_loadl_64(a + a_stride * 3); - const __m128i v_b0 = xx_loadl_64(b); - const __m128i v_b1 = xx_loadl_64(b + b_stride); - const __m128i v_b2 = xx_loadl_64(b + b_stride * 2); - const __m128i v_b3 = xx_loadl_64(b + b_stride * 3); - const __m256i v_a_w = yy_set_m128i(_mm_unpacklo_epi64(v_a0, v_a1), - _mm_unpacklo_epi64(v_a2, v_a3)); - const __m256i v_b_w = yy_set_m128i(_mm_unpacklo_epi64(v_b0, v_b1), - _mm_unpacklo_epi64(v_b2, v_b3)); - const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w); - sum = _mm256_add_epi32(sum, _mm256_madd_epi16(v_d_w, v_d_w)); + highbd_sse_w4x4_avx2(&sum, a, a_stride, b, b_stride); a += a_stride << 2; b += b_stride << 2; y += 4; @@ -187,10 +257,7 @@ break; case 8: do { - const __m256i v_a_w = yy_loadu2_128(a + a_stride, a); - const __m256i v_b_w = yy_loadu2_128(b + b_stride, b); - const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w); - sum = _mm256_add_epi32(sum, _mm256_madd_epi16(v_d_w, v_d_w)); + highbd_sse_w8x2_avx2(&sum, a, a_stride, b, b_stride); a += a_stride << 1; b += b_stride << 1; y += 2; @@ -208,43 +275,98 @@ break; case 32: do { - highbd_sse_w16_avx2(&sum, a, b); - highbd_sse_w16_avx2(&sum, a + 16, b + 16); - a += a_stride; - b += b_stride; - y += 1; + int l = 0; + __m256i sum32 = _mm256_setzero_si256(); + do { + highbd_sse_w16_avx2(&sum32, a, b); + highbd_sse_w16_avx2(&sum32, a + 16, b + 16); + a += a_stride; + b += b_stride; + l += 1; + } while (l < 64 && l < (height - y)); + summary_32_avx2(&sum32, &sum); + y += 64; } while (y < height); - sse = summary_all_avx2(&sum); + sse = summary_4x64_avx2(sum); break; case 64: do { - highbd_sse_w16_avx2(&sum, a, b); - highbd_sse_w16_avx2(&sum, a + 16 * 1, b + 16 * 1); - highbd_sse_w16_avx2(&sum, a + 16 * 2, b + 16 * 2); - highbd_sse_w16_avx2(&sum, a + 16 * 3, b + 16 * 3); - a += a_stride; - b += b_stride; - y += 1; + int l = 0; + __m256i sum32 = _mm256_setzero_si256(); + do { + highbd_sse_w16_avx2(&sum32, a, b); + highbd_sse_w16_avx2(&sum32, a + 16 * 1, b + 16 * 1); + highbd_sse_w16_avx2(&sum32, a + 16 * 2, b + 16 * 2); + highbd_sse_w16_avx2(&sum32, a + 16 * 3, b + 16 * 3); + a += a_stride; + b += b_stride; + l += 1; + } while (l < 32 && l < (height - y)); + summary_32_avx2(&sum32, &sum); + y += 32; } while (y < height); - sse = summary_all_avx2(&sum); + sse = summary_4x64_avx2(sum); break; case 128: do { - highbd_sse_w16_avx2(&sum, a, b); - highbd_sse_w16_avx2(&sum, a + 16 * 1, b + 16 * 1); - highbd_sse_w16_avx2(&sum, a + 16 * 2, b + 16 * 2); - highbd_sse_w16_avx2(&sum, a + 16 * 3, b + 16 * 3); - highbd_sse_w16_avx2(&sum, a + 16 * 4, b + 16 * 4); - highbd_sse_w16_avx2(&sum, a + 16 * 5, b + 16 * 5); - highbd_sse_w16_avx2(&sum, a + 16 * 6, b + 16 * 6); - highbd_sse_w16_avx2(&sum, a + 16 * 7, b + 16 * 7); - a += a_stride; - b += b_stride; - y += 1; + int l = 0; + __m256i sum32 = _mm256_setzero_si256(); + do { + highbd_sse_w16_avx2(&sum32, a, b); + highbd_sse_w16_avx2(&sum32, a + 16 * 1, b + 16 * 1); + highbd_sse_w16_avx2(&sum32, a + 16 * 2, b + 16 * 2); + highbd_sse_w16_avx2(&sum32, a + 16 * 3, b + 16 * 3); + highbd_sse_w16_avx2(&sum32, a + 16 * 4, b + 16 * 4); + highbd_sse_w16_avx2(&sum32, a + 16 * 5, b + 16 * 5); + highbd_sse_w16_avx2(&sum32, a + 16 * 6, b + 16 * 6); + highbd_sse_w16_avx2(&sum32, a + 16 * 7, b + 16 * 7); + a += a_stride; + b += b_stride; + l += 1; + } while (l < 16 && l < (height - y)); + summary_32_avx2(&sum32, &sum); + y += 16; } while (y < height); - sse = summary_all_avx2(&sum); + sse = summary_4x64_avx2(sum); break; - default: break; + default: + if (width & 0x7) { + do { + int i = 0; + __m256i sum32 = _mm256_setzero_si256(); + do { + highbd_sse_w8x2_avx2(&sum32, a + i, a_stride, b + i, b_stride); + const uint16_t *a2 = a + i + (a_stride << 1); + const uint16_t *b2 = b + i + (b_stride << 1); + highbd_sse_w8x2_avx2(&sum32, a2, a_stride, b2, b_stride); + i += 8; + } while (i + 4 < width); + highbd_sse_w4x4_avx2(&sum32, a + i, a_stride, b + i, b_stride); + summary_32_avx2(&sum32, &sum); + a += a_stride << 2; + b += b_stride << 2; + y += 4; + } while (y < height); + } else { + do { + int l = 0; + __m256i sum32 = _mm256_setzero_si256(); + do { + int i = 0; + do { + highbd_sse_w8x2_avx2(&sum32, a + i, a_stride, b + i, b_stride); + i += 8; + } while (i < width); + a += a_stride << 1; + b += b_stride << 1; + l += 2; + } while (l < 8 && l < (height - y)); + summary_32_avx2(&sum32, &sum); + y += 8; + } while (y < height); + } + sse = summary_4x64_avx2(sum); + break; } return sse; }
diff --git a/aom_dsp/x86/sse_sse4.c b/aom_dsp/x86/sse_sse4.c index 8b5af84..0d45003 100644 --- a/aom_dsp/x86/sse_sse4.c +++ b/aom_dsp/x86/sse_sse4.c
@@ -28,6 +28,13 @@ return sum; } +static INLINE void summary_32_sse4(const __m128i *sum32, __m128i *sum64) { + const __m128i sum0 = _mm_cvtepu32_epi64(*sum32); + const __m128i sum1 = _mm_cvtepu32_epi64(_mm_srli_si128(*sum32, 8)); + *sum64 = _mm_add_epi64(sum0, *sum64); + *sum64 = _mm_add_epi64(sum1, *sum64); +} + static INLINE void sse_w16_sse4_1(__m128i *sum, const uint8_t *a, const uint8_t *b) { const __m128i v_a0 = xx_loadu_128(a); @@ -42,6 +49,28 @@ *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d01_w, v_d01_w)); } +static INLINE void aom_sse4x2_sse4_1(const uint8_t *a, int a_stride, + const uint8_t *b, int b_stride, + __m128i *sum) { + const __m128i v_a0 = xx_loadl_32(a); + const __m128i v_a1 = xx_loadl_32(a + a_stride); + const __m128i v_b0 = xx_loadl_32(b); + const __m128i v_b1 = xx_loadl_32(b + b_stride); + const __m128i v_a_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_a0, v_a1)); + const __m128i v_b_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_b0, v_b1)); + const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w); + *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w)); +} +static INLINE void aom_sse8_sse4_1(const uint8_t *a, const uint8_t *b, + __m128i *sum) { + const __m128i v_a0 = xx_loadl_64(a); + const __m128i v_b0 = xx_loadl_64(b); + const __m128i v_a_w = _mm_cvtepu8_epi16(v_a0); + const __m128i v_b_w = _mm_cvtepu8_epi16(v_b0); + const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w); + *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w)); +} + int64_t aom_sse_sse4_1(const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height) { int y = 0; @@ -50,14 +79,7 @@ switch (width) { case 4: do { - const __m128i v_a0 = xx_loadl_32(a); - const __m128i v_a1 = xx_loadl_32(a + a_stride); - const __m128i v_b0 = xx_loadl_32(b); - const __m128i v_b1 = xx_loadl_32(b + b_stride); - const __m128i v_a_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_a0, v_a1)); - const __m128i v_b_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_b0, v_b1)); - const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w); - sum = _mm_add_epi32(sum, _mm_madd_epi16(v_d_w, v_d_w)); + aom_sse4x2_sse4_1(a, a_stride, b, b_stride, &sum); a += a_stride << 1; b += b_stride << 1; y += 2; @@ -66,12 +88,7 @@ break; case 8: do { - const __m128i v_a0 = xx_loadl_64(a); - const __m128i v_b0 = xx_loadl_64(b); - const __m128i v_a_w = _mm_cvtepu8_epi16(v_a0); - const __m128i v_b_w = _mm_cvtepu8_epi16(v_b0); - const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w); - sum = _mm_add_epi32(sum, _mm_madd_epi16(v_d_w, v_d_w)); + aom_sse8_sse4_1(a, b, &sum); a += a_stride; b += b_stride; y += 1; @@ -125,12 +142,52 @@ } while (y < height); sse = summary_all_sse4(&sum); break; - default: break; + default: + if (width & 0x07) { + do { + int i = 0; + do { + aom_sse8_sse4_1(a + i, b + i, &sum); + aom_sse8_sse4_1(a + i + a_stride, b + i + b_stride, &sum); + i += 8; + } while (i + 4 < width); + aom_sse4x2_sse4_1(a + i, a_stride, b + i, b_stride, &sum); + a += (a_stride << 1); + b += (b_stride << 1); + y += 2; + } while (y < height); + } else { + do { + int i = 0; + do { + aom_sse8_sse4_1(a + i, b + i, &sum); + i += 8; + } while (i < width); + a += a_stride; + b += b_stride; + y += 1; + } while (y < height); + } + sse = summary_all_sse4(&sum); + break; } return sse; } +static INLINE void highbd_sse_w4x2_sse4_1(__m128i *sum, const uint16_t *a, + int a_stride, const uint16_t *b, + int b_stride) { + const __m128i v_a0 = xx_loadl_64(a); + const __m128i v_a1 = xx_loadl_64(a + a_stride); + const __m128i v_b0 = xx_loadl_64(b); + const __m128i v_b1 = xx_loadl_64(b + b_stride); + const __m128i v_a_w = _mm_unpacklo_epi64(v_a0, v_a1); + const __m128i v_b_w = _mm_unpacklo_epi64(v_b0, v_b1); + const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w); + *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w)); +} + static INLINE void highbd_sse_w8_sse4_1(__m128i *sum, const uint16_t *a, const uint16_t *b) { const __m128i v_a_w = xx_loadu_128(a); @@ -150,14 +207,7 @@ switch (width) { case 4: do { - const __m128i v_a0 = xx_loadl_64(a); - const __m128i v_a1 = xx_loadl_64(a + a_stride); - const __m128i v_b0 = xx_loadl_64(b); - const __m128i v_b1 = xx_loadl_64(b + b_stride); - const __m128i v_a_w = _mm_unpacklo_epi64(v_a0, v_a1); - const __m128i v_b_w = _mm_unpacklo_epi64(v_b0, v_b1); - const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w); - sum = _mm_add_epi32(sum, _mm_madd_epi16(v_d_w, v_d_w)); + highbd_sse_w4x2_sse4_1(&sum, a, a_stride, b, b_stride); a += a_stride << 1; b += b_stride << 1; y += 2; @@ -175,67 +225,126 @@ break; case 16: do { - highbd_sse_w8_sse4_1(&sum, a, b); - highbd_sse_w8_sse4_1(&sum, a + 8, b + 8); - a += a_stride; - b += b_stride; - y += 1; + int l = 0; + __m128i sum32 = _mm_setzero_si128(); + do { + highbd_sse_w8_sse4_1(&sum32, a, b); + highbd_sse_w8_sse4_1(&sum32, a + 8, b + 8); + a += a_stride; + b += b_stride; + l += 1; + } while (l < 64 && l < (height - y)); + summary_32_sse4(&sum32, &sum); + y += 64; } while (y < height); - sse = summary_all_sse4(&sum); + xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8))); break; case 32: do { - highbd_sse_w8_sse4_1(&sum, a, b); - highbd_sse_w8_sse4_1(&sum, a + 8 * 1, b + 8 * 1); - highbd_sse_w8_sse4_1(&sum, a + 8 * 2, b + 8 * 2); - highbd_sse_w8_sse4_1(&sum, a + 8 * 3, b + 8 * 3); - a += a_stride; - b += b_stride; - y += 1; + int l = 0; + __m128i sum32 = _mm_setzero_si128(); + do { + highbd_sse_w8_sse4_1(&sum32, a, b); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3); + a += a_stride; + b += b_stride; + l += 1; + } while (l < 32 && l < (height - y)); + summary_32_sse4(&sum32, &sum); + y += 32; } while (y < height); - sse = summary_all_sse4(&sum); + xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8))); break; case 64: do { - highbd_sse_w8_sse4_1(&sum, a, b); - highbd_sse_w8_sse4_1(&sum, a + 8 * 1, b + 8 * 1); - highbd_sse_w8_sse4_1(&sum, a + 8 * 2, b + 8 * 2); - highbd_sse_w8_sse4_1(&sum, a + 8 * 3, b + 8 * 3); - highbd_sse_w8_sse4_1(&sum, a + 8 * 4, b + 8 * 4); - highbd_sse_w8_sse4_1(&sum, a + 8 * 5, b + 8 * 5); - highbd_sse_w8_sse4_1(&sum, a + 8 * 6, b + 8 * 6); - highbd_sse_w8_sse4_1(&sum, a + 8 * 7, b + 8 * 7); - a += a_stride; - b += b_stride; - y += 1; + int l = 0; + __m128i sum32 = _mm_setzero_si128(); + do { + highbd_sse_w8_sse4_1(&sum32, a, b); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 4, b + 8 * 4); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 5, b + 8 * 5); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 6, b + 8 * 6); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 7, b + 8 * 7); + a += a_stride; + b += b_stride; + l += 1; + } while (l < 16 && l < (height - y)); + summary_32_sse4(&sum32, &sum); + y += 16; } while (y < height); - sse = summary_all_sse4(&sum); + xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8))); break; case 128: do { - highbd_sse_w8_sse4_1(&sum, a, b); - highbd_sse_w8_sse4_1(&sum, a + 8 * 1, b + 8 * 1); - highbd_sse_w8_sse4_1(&sum, a + 8 * 2, b + 8 * 2); - highbd_sse_w8_sse4_1(&sum, a + 8 * 3, b + 8 * 3); - highbd_sse_w8_sse4_1(&sum, a + 8 * 4, b + 8 * 4); - highbd_sse_w8_sse4_1(&sum, a + 8 * 5, b + 8 * 5); - highbd_sse_w8_sse4_1(&sum, a + 8 * 6, b + 8 * 6); - highbd_sse_w8_sse4_1(&sum, a + 8 * 7, b + 8 * 7); - highbd_sse_w8_sse4_1(&sum, a + 8 * 8, b + 8 * 8); - highbd_sse_w8_sse4_1(&sum, a + 8 * 9, b + 8 * 9); - highbd_sse_w8_sse4_1(&sum, a + 8 * 10, b + 8 * 10); - highbd_sse_w8_sse4_1(&sum, a + 8 * 11, b + 8 * 11); - highbd_sse_w8_sse4_1(&sum, a + 8 * 12, b + 8 * 12); - highbd_sse_w8_sse4_1(&sum, a + 8 * 13, b + 8 * 13); - highbd_sse_w8_sse4_1(&sum, a + 8 * 14, b + 8 * 14); - highbd_sse_w8_sse4_1(&sum, a + 8 * 15, b + 8 * 15); - a += a_stride; - b += b_stride; - y += 1; + int l = 0; + __m128i sum32 = _mm_setzero_si128(); + do { + highbd_sse_w8_sse4_1(&sum32, a, b); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 4, b + 8 * 4); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 5, b + 8 * 5); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 6, b + 8 * 6); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 7, b + 8 * 7); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 8, b + 8 * 8); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 9, b + 8 * 9); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 10, b + 8 * 10); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 11, b + 8 * 11); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 12, b + 8 * 12); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 13, b + 8 * 13); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 14, b + 8 * 14); + highbd_sse_w8_sse4_1(&sum32, a + 8 * 15, b + 8 * 15); + a += a_stride; + b += b_stride; + l += 1; + } while (l < 8 && l < (height - y)); + summary_32_sse4(&sum32, &sum); + y += 8; } while (y < height); - sse = summary_all_sse4(&sum); + xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8))); break; - default: break; + default: + if (width & 0x7) { + do { + __m128i sum32 = _mm_setzero_si128(); + int i = 0; + do { + highbd_sse_w8_sse4_1(&sum32, a + i, b + i); + highbd_sse_w8_sse4_1(&sum32, a + i + a_stride, b + i + b_stride); + i += 8; + } while (i + 4 < width); + highbd_sse_w4x2_sse4_1(&sum32, a + i, a_stride, b + i, b_stride); + a += (a_stride << 1); + b += (b_stride << 1); + y += 2; + summary_32_sse4(&sum32, &sum); + } while (y < height); + } else { + do { + int l = 0; + __m128i sum32 = _mm_setzero_si128(); + do { + int i = 0; + do { + highbd_sse_w8_sse4_1(&sum32, a + i, b + i); + i += 8; + } while (i < width); + a += a_stride; + b += b_stride; + l += 1; + } while (l < 8 && l < (height - y)); + summary_32_sse4(&sum32, &sum); + y += 8; + } while (y < height); + } + xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8))); + break; } return sse; }
diff --git a/test/sum_squares_test.cc b/test/sum_squares_test.cc index f109984..cb518c8 100644 --- a/test/sum_squares_test.cc +++ b/test/sum_squares_test.cc
@@ -27,6 +27,10 @@ using libaom_test::ACMRandom; using libaom_test::FunctionEquivalenceTest; +using ::testing::Combine; +using ::testing::Range; +using ::testing::Values; +using ::testing::ValuesIn; namespace { const int kNumIterations = 10000; @@ -225,4 +229,142 @@ aom_sum_squares_i16_c, aom_sum_squares_i16_sse2))); #endif // HAVE_SSE2 + +typedef int64_t (*sse_func)(const uint8_t *a, int a_stride, const uint8_t *b, + int b_stride, int width, int height); +typedef libaom_test::FuncParam<sse_func> TestSSEFuncs; + +typedef ::testing::tuple<TestSSEFuncs, int> SSETestParam; + +class SSETest : public ::testing::TestWithParam<SSETestParam> { + public: + virtual ~SSETest() {} + virtual void SetUp() { + params_ = GET_PARAM(0); + width_ = GET_PARAM(1); + isHbd_ = params_.ref_func == aom_highbd_sse_c; + rnd_.Reset(ACMRandom::DeterministicSeed()); + src_ = reinterpret_cast<uint8_t *>(aom_memalign(32, 256 * 256 * 2)); + ref_ = reinterpret_cast<uint8_t *>(aom_memalign(32, 256 * 256 * 2)); + ASSERT_TRUE(src_ != NULL); + ASSERT_TRUE(ref_ != NULL); + } + + virtual void TearDown() { + libaom_test::ClearSystemState(); + aom_free(src_); + aom_free(ref_); + } + void RunTest(int isRandom, int width, int height); + + void GenRandomData(int width, int height, int stride) { + uint16_t *pSrc = (uint16_t *)src_; + uint16_t *pRef = (uint16_t *)ref_; + const int msb = 11; // Up to 12 bit input + const int limit = 1 << (msb + 1); + for (int ii = 0; ii < height; ii++) { + for (int jj = 0; jj < width; jj++) { + if (!isHbd_) { + src_[ii * stride + jj] = rnd_.Rand8(); + ref_[ii * stride + jj] = rnd_.Rand8(); + } else { + pSrc[ii * stride + jj] = rnd_(limit); + pRef[ii * stride + jj] = rnd_(limit); + } + } + } + } + + void GenExtremeData(int width, int height, int stride, uint8_t *data, + int16_t val) { + uint16_t *pData = (uint16_t *)data; + for (int ii = 0; ii < height; ii++) { + for (int jj = 0; jj < width; jj++) { + if (!isHbd_) { + data[ii * stride + jj] = (uint8_t)val; + } else { + pData[ii * stride + jj] = val; + } + } + } + } + + protected: + int isHbd_; + int width_; + TestSSEFuncs params_; + uint8_t *src_; + uint8_t *ref_; + ACMRandom rnd_; +}; + +void SSETest::RunTest(int isRandom, int width, int height) { + int failed = 0; + for (int k = 0; k < 3; k++) { + int stride = 4 << rnd_(7); // Up to 256 stride + while (stride < width) { // Make sure it's valid + stride = 4 << rnd_(7); + } + if (isRandom) { + GenRandomData(width, height, stride); + } else { + const int msb = isHbd_ ? 12 : 8; // Up to 12 bit input + const int limit = (1 << msb) - 1; + if (k == 0) { + GenExtremeData(width, height, stride, src_, 0); + GenExtremeData(width, height, stride, ref_, limit); + } else { + GenExtremeData(width, height, stride, src_, limit); + GenExtremeData(width, height, stride, ref_, 0); + } + } + int64_t res_ref, res_tst; + uint8_t *pSrc = src_; + uint8_t *pRef = ref_; + if (isHbd_) { + pSrc = CONVERT_TO_BYTEPTR(src_); + pRef = CONVERT_TO_BYTEPTR(ref_); + } + res_ref = params_.ref_func(pSrc, stride, pRef, stride, width, height); + + ASM_REGISTER_STATE_CHECK( + res_tst = params_.tst_func(pSrc, stride, pRef, stride, width, height)); + + if (!failed) { + failed = res_ref != res_tst; + EXPECT_EQ(res_ref, res_tst) + << "Error:" << (isHbd_ ? "hbd " : " ") << k << " SSE Test [" << width + << "x" << height << "] C output does not match optimized output."; + } + } +} + +TEST_P(SSETest, OperationCheck) { + for (int height = 4; height <= 128; height += 4) { + RunTest(1, width_, height); // GenRandomData + } +} + +TEST_P(SSETest, ExtremeValues) { + for (int height = 4; height <= 128; height += 4) { + RunTest(0, width_, height); + } +} + +#if HAVE_SSE4_1 +TestSSEFuncs sse_sse4[] = { TestSSEFuncs(&aom_sse_c, &aom_sse_sse4_1), + TestSSEFuncs(&aom_highbd_sse_c, + &aom_highbd_sse_sse4_1) }; +INSTANTIATE_TEST_CASE_P(SSE4_1, SSETest, + Combine(ValuesIn(sse_sse4), Range(4, 129, 4))); +#endif // HAVE_SSE4_1 + +#if HAVE_AVX2 + +TestSSEFuncs sse_avx2[] = { TestSSEFuncs(&aom_sse_c, &aom_sse_avx2), + TestSSEFuncs(&aom_highbd_sse_c, + &aom_highbd_sse_avx2) }; +INSTANTIATE_TEST_CASE_P(AVX2, SSETest, + Combine(ValuesIn(sse_avx2), Range(4, 129, 4))); +#endif // HAVE_AVX2 } // namespace