Add SSE2 variant for mse_wxh_16bit in cdef Added SSE2 variant for mse_wxh_16bit function and unit test (MseWxHTest). Module level gains: BLOCKSIZE Gain w.r.t. C 8x8 2.7x 8x4 2.5x 4x8 2.5x 4x4 2.2x Change-Id: I9cdc4fc609ed074626e6533e14b5385564253523
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl index 79fba0a..d6b27a6 100755 --- a/aom_dsp/aom_dsp_rtcd_defs.pl +++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1130,7 +1130,7 @@ add_proto qw/unsigned int/, "aom_variance4x2", "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse"; add_proto qw/uint64_t/, "aom_mse_wxh_16bit", "uint8_t *dst, int dstride,uint16_t *src, int sstride, int w, int h"; - specialize qw/aom_mse_wxh_16bit avx2/; + specialize qw/aom_mse_wxh_16bit sse2 avx2/; foreach (@block_sizes) { ($w, $h) = @$_;
diff --git a/aom_dsp/x86/variance_sse2.c b/aom_dsp/x86/variance_sse2.c index 97f71fc..1a24a37 100644 --- a/aom_dsp/x86/variance_sse2.c +++ b/aom_dsp/x86/variance_sse2.c
@@ -756,3 +756,98 @@ } while (i < height); } } + +uint64_t aom_mse_4xh_16bit_sse2(uint8_t *dst, int dstride, uint16_t *src, + int sstride, int h) { + uint64_t sum = 0; + __m128i dst0_8x8, dst1_8x8, dst_16x8; + __m128i src0_16x4, src1_16x4, src_16x8; + __m128i res0_32x4, res1_32x4, res0_64x4, res1_64x4, res2_64x4, res3_64x4; + __m128i sub_result_16x8; + const __m128i zeros = _mm_setzero_si128(); + __m128i square_result = _mm_setzero_si128(); + for (int i = 0; i < h; i += 2) { + dst0_8x8 = _mm_cvtsi32_si128(*(uint32_t const *)(&dst[(i + 0) * dstride])); + dst1_8x8 = _mm_cvtsi32_si128(*(uint32_t const *)(&dst[(i + 1) * dstride])); + dst_16x8 = _mm_unpacklo_epi8(_mm_unpacklo_epi32(dst0_8x8, dst1_8x8), zeros); + + src0_16x4 = _mm_loadl_epi64((__m128i const *)(&src[(i + 0) * sstride])); + src1_16x4 = _mm_loadl_epi64((__m128i const *)(&src[(i + 1) * sstride])); + src_16x8 = _mm_unpacklo_epi64(src0_16x4, src1_16x4); + + sub_result_16x8 = _mm_sub_epi16(src_16x8, dst_16x8); + + res0_32x4 = _mm_unpacklo_epi16(sub_result_16x8, zeros); + res1_32x4 = _mm_unpackhi_epi16(sub_result_16x8, zeros); + + res0_32x4 = _mm_madd_epi16(res0_32x4, res0_32x4); + res1_32x4 = _mm_madd_epi16(res1_32x4, res1_32x4); + + res0_64x4 = _mm_unpacklo_epi32(res0_32x4, zeros); + res1_64x4 = _mm_unpackhi_epi32(res0_32x4, zeros); + res2_64x4 = _mm_unpacklo_epi32(res1_32x4, zeros); + res3_64x4 = _mm_unpackhi_epi32(res1_32x4, zeros); + + square_result = _mm_add_epi64( + square_result, + _mm_add_epi64( + _mm_add_epi64(_mm_add_epi64(res0_64x4, res1_64x4), res2_64x4), + res3_64x4)); + } + const __m128i sum_1x64 = + _mm_add_epi64(square_result, _mm_srli_si128(square_result, 8)); + xx_storel_64(&sum, sum_1x64); + return sum; +} + +uint64_t aom_mse_8xh_16bit_sse2(uint8_t *dst, int dstride, uint16_t *src, + int sstride, int h) { + uint64_t sum = 0; + __m128i dst_8x8, dst_16x8; + __m128i src_16x8; + __m128i res0_32x4, res1_32x4, res0_64x4, res1_64x4, res2_64x4, res3_64x4; + __m128i sub_result_16x8; + const __m128i zeros = _mm_setzero_si128(); + __m128i square_result = _mm_setzero_si128(); + + for (int i = 0; i < h; i++) { + dst_8x8 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 0) * dstride])); + dst_16x8 = _mm_unpacklo_epi8(dst_8x8, zeros); + + src_16x8 = _mm_loadu_si128((__m128i *)&src[i * sstride]); + + sub_result_16x8 = _mm_sub_epi16(src_16x8, dst_16x8); + + res0_32x4 = _mm_unpacklo_epi16(sub_result_16x8, zeros); + res1_32x4 = _mm_unpackhi_epi16(sub_result_16x8, zeros); + + res0_32x4 = _mm_madd_epi16(res0_32x4, res0_32x4); + res1_32x4 = _mm_madd_epi16(res1_32x4, res1_32x4); + + res0_64x4 = _mm_unpacklo_epi32(res0_32x4, zeros); + res1_64x4 = _mm_unpackhi_epi32(res0_32x4, zeros); + res2_64x4 = _mm_unpacklo_epi32(res1_32x4, zeros); + res3_64x4 = _mm_unpackhi_epi32(res1_32x4, zeros); + + square_result = _mm_add_epi64( + square_result, + _mm_add_epi64( + _mm_add_epi64(_mm_add_epi64(res0_64x4, res1_64x4), res2_64x4), + res3_64x4)); + } + const __m128i sum_1x64 = + _mm_add_epi64(square_result, _mm_srli_si128(square_result, 8)); + xx_storel_64(&sum, sum_1x64); + return sum; +} + +uint64_t aom_mse_wxh_16bit_sse2(uint8_t *dst, int dstride, uint16_t *src, + int sstride, int w, int h) { + assert((w == 8 || w == 4) && (h == 8 || h == 4) && + "w=8/4 and h=8/4 must satisfy"); + switch (w) { + case 4: return aom_mse_4xh_16bit_sse2(dst, dstride, src, sstride, h); + case 8: return aom_mse_8xh_16bit_sse2(dst, dstride, src, sstride, h); + default: assert(0 && "unsupported width"); return -1; + } +}
diff --git a/test/variance_test.cc b/test/variance_test.cc index 4c016af..9932bef 100644 --- a/test/variance_test.cc +++ b/test/variance_test.cc
@@ -1172,7 +1172,7 @@ params_.bit_depth, elapsed_time); } -typedef MseWxHTestClass<MseWxH16bitFunc> AvxMseWxHTest; +typedef MseWxHTestClass<MseWxH16bitFunc> MseWxHTest; typedef MainTestClass<Get4x4SseFunc> AvxSseTest; typedef MainTestClass<VarianceMxNFunc> AvxMseTest; typedef MainTestClass<VarianceMxNFunc> AvxVarianceTest; @@ -1181,11 +1181,12 @@ typedef SubpelVarianceTest<DistWtdSubpixAvgVarMxNFunc> AvxDistWtdSubpelAvgVarianceTest; typedef ObmcVarianceTest<ObmcSubpelVarFunc> AvxObmcSubpelVarianceTest; +typedef TestParams<MseWxH16bitFunc> MseWxHParams; TEST_P(AvxSseTest, RefSse) { RefTestSse(); } TEST_P(AvxSseTest, MaxSse) { MaxTestSse(); } -TEST_P(AvxMseWxHTest, RefMse) { RefMatchTestMse(); } -TEST_P(AvxMseWxHTest, DISABLED_SpeedMse) { SpeedTest(); } +TEST_P(MseWxHTest, RefMse) { RefMatchTestMse(); } +TEST_P(MseWxHTest, DISABLED_SpeedMse) { SpeedTest(); } TEST_P(AvxMseTest, RefMse) { RefTestMse(); } TEST_P(AvxMseTest, MaxMse) { MaxTestMse(); } TEST_P(AvxVarianceTest, Zero) { ZeroTest(); } @@ -1795,6 +1796,13 @@ #endif // CONFIG_AV1_HIGHBITDEPTH #if HAVE_SSE2 +INSTANTIATE_TEST_SUITE_P( + SSE2, MseWxHTest, + ::testing::Values(MseWxHParams(3, 3, &aom_mse_wxh_16bit_sse2, 8), + MseWxHParams(3, 2, &aom_mse_wxh_16bit_sse2, 8), + MseWxHParams(2, 3, &aom_mse_wxh_16bit_sse2, 8), + MseWxHParams(2, 2, &aom_mse_wxh_16bit_sse2, 8))); + INSTANTIATE_TEST_SUITE_P(SSE2, SumOfSquaresTest, ::testing::Values(aom_get_mb_ss_sse2)); @@ -2380,9 +2388,8 @@ #if HAVE_AVX2 -typedef TestParams<MseWxH16bitFunc> MseWxHParams; INSTANTIATE_TEST_SUITE_P( - AVX2, AvxMseWxHTest, + AVX2, MseWxHTest, ::testing::Values(MseWxHParams(3, 3, &aom_mse_wxh_16bit_avx2, 8), MseWxHParams(3, 2, &aom_mse_wxh_16bit_avx2, 8), MseWxHParams(2, 3, &aom_mse_wxh_16bit_avx2, 8),