Add pixel_proj_error simd code 1. add low bitdepth pixel_proj_error sse4_1 code 2. add low bitdepth pixel_proj_error avx2 code 3. add low bitdepth pixel_proj_error unittest Speed up about 0.6% without rd change test sequence: BasketballDrill_832x480_50.y4m test command line:./aomenc --cpu-used=1 --psnr -D \ -q --end-usage=vbr --target-bitrate=800 --limit=20 \ BasketballDrill_832x480_50.y4m -otest.webm Change-Id: Ia50e6c7ff70da24e5564cc52069e54b585d7c8c1
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl index 0bff779..f149ab0 100755 --- a/av1/common/av1_rtcd_defs.pl +++ b/av1/common/av1_rtcd_defs.pl
@@ -253,6 +253,9 @@ add_proto qw/void av1_compute_stats/, "int wiener_win, const uint8_t *dgd8, const uint8_t *src8, int h_start, int h_end, int v_start, int v_end, int dgd_stride, int src_stride, double *M, double *H"; specialize qw/av1_compute_stats sse4_1 avx2/; + + add_proto qw/int64_t av1_lowbd_pixel_proj_error/, " const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params"; + specialize qw/av1_lowbd_pixel_proj_error sse4_1 avx2/; } # end encoder functions
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c index 86bfe04..f6a1bae 100644 --- a/av1/encoder/pickrst.c +++ b/av1/encoder/pickrst.c
@@ -181,6 +181,77 @@ return sse_restoration_unit(limits, rsc->src, rsc->dst, plane, highbd); } +int64_t av1_lowbd_pixel_proj_error_c(const uint8_t *src8, int width, int height, + int src_stride, const uint8_t *dat8, + int dat_stride, int32_t *flt0, + int flt0_stride, int32_t *flt1, + int flt1_stride, int xq[2], + const sgr_params_type *params) { + int i, j; + const uint8_t *src = src8; + const uint8_t *dat = dat8; + int64_t err = 0; + if (params->r[0] > 0 && params->r[1] > 0) { + for (i = 0; i < height; ++i) { + for (j = 0; j < width; ++j) { + assert(flt1[j] < (1 << 15) && flt1[j] > -(1 << 15)); + assert(flt0[j] < (1 << 15) && flt0[j] > -(1 << 15)); + const int32_t u = (int32_t)(dat[j] << SGRPROJ_RST_BITS); + int32_t v = u << SGRPROJ_PRJ_BITS; + v += xq[0] * (flt0[j] - u) + xq[1] * (flt1[j] - u); + const int32_t e = + ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) - src[j]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + flt0 += flt0_stride; + flt1 += flt1_stride; + } + } else if (params->r[0] > 0) { + for (i = 0; i < height; ++i) { + for (j = 0; j < width; ++j) { + assert(flt0[j] < (1 << 15) && flt0[j] > -(1 << 15)); + const int32_t u = (int32_t)(dat[j] << SGRPROJ_RST_BITS); + int32_t v = u << SGRPROJ_PRJ_BITS; + v += xq[0] * (flt0[j] - u); + const int32_t e = + ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) - src[j]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + flt0 += flt0_stride; + } + } else if (params->r[1] > 0) { + for (i = 0; i < height; ++i) { + for (j = 0; j < width; ++j) { + assert(flt1[j] < (1 << 15) && flt1[j] > -(1 << 15)); + const int32_t u = (int32_t)(dat[j] << SGRPROJ_RST_BITS); + int32_t v = u << SGRPROJ_PRJ_BITS; + v += xq[1] * (flt1[j] - u); + const int32_t e = + ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) - src[j]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + flt1 += flt1_stride; + } + } else { + for (i = 0; i < height; ++i) { + for (j = 0; j < width; ++j) { + const int32_t e = (int32_t)(dat[j]) - src[j]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + } + } + + return err; +} + static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int use_highbitdepth, @@ -192,21 +263,9 @@ int xq[2]; decode_xq(xqd, xq, params); if (!use_highbitdepth) { - const uint8_t *src = src8; - const uint8_t *dat = dat8; - for (i = 0; i < height; ++i) { - for (j = 0; j < width; ++j) { - const int32_t u = - (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS); - int32_t v = u << SGRPROJ_PRJ_BITS; - if (params->r[0] > 0) v += xq[0] * (flt0[i * flt0_stride + j] - u); - if (params->r[1] > 0) v += xq[1] * (flt1[i * flt1_stride + j] - u); - const int32_t e = - ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) - - src[i * src_stride + j]; - err += e * e; - } - } + err = av1_lowbd_pixel_proj_error(src8, width, height, src_stride, dat8, + dat_stride, flt0, flt0_stride, flt1, + flt1_stride, xq, params); } else { const uint16_t *src = CONVERT_TO_SHORTPTR(src8); const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
diff --git a/av1/encoder/x86/pickrst_avx2.c b/av1/encoder/x86/pickrst_avx2.c index 31a063b..548f6c3 100644 --- a/av1/encoder/x86/pickrst_avx2.c +++ b/av1/encoder/x86/pickrst_avx2.c
@@ -227,3 +227,176 @@ dgd_stride, src_stride, M, H); } } + +static INLINE __m256i pair_set_epi16(uint16_t a, uint16_t b) { + return _mm256_set1_epi32( + (int32_t)(((uint16_t)(a)) | (((uint32_t)(b)) << 16))); +} + +int64_t av1_lowbd_pixel_proj_error_avx2( + const uint8_t *src8, int width, int height, int src_stride, + const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, + int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) { + int i, j, k; + const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS; + const __m256i rounding = _mm256_set1_epi32(1 << (shift - 1)); + __m256i sum64 = _mm256_setzero_si256(); + const uint8_t *src = src8; + const uint8_t *dat = dat8; + int64_t err = 0; + if (params->r[0] > 0 && params->r[1] > 0) { + __m256i xq_coeff = pair_set_epi16(xq[0], xq[1]); + for (i = 0; i < height; ++i) { + __m256i sum32 = _mm256_setzero_si256(); + for (j = 0; j <= width - 16; j += 16) { + const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j)); + const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j)); + const __m256i flt0_16b = _mm256_permute4x64_epi64( + _mm256_packs_epi32(yy_loadu_256(flt0 + j), + yy_loadu_256(flt0 + j + 8)), + 0xd8); + const __m256i flt1_16b = _mm256_permute4x64_epi64( + _mm256_packs_epi32(yy_loadu_256(flt1 + j), + yy_loadu_256(flt1 + j + 8)), + 0xd8); + const __m256i u0 = _mm256_slli_epi16(d0, SGRPROJ_RST_BITS); + const __m256i flt0_0_sub_u = _mm256_sub_epi16(flt0_16b, u0); + const __m256i flt1_0_sub_u = _mm256_sub_epi16(flt1_16b, u0); + const __m256i v0 = _mm256_madd_epi16( + xq_coeff, _mm256_unpacklo_epi16(flt0_0_sub_u, flt1_0_sub_u)); + const __m256i v1 = _mm256_madd_epi16( + xq_coeff, _mm256_unpackhi_epi16(flt0_0_sub_u, flt1_0_sub_u)); + const __m256i vr0 = + _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift); + const __m256i vr1 = + _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift); + const __m256i e0 = _mm256_sub_epi16( + _mm256_add_epi16(_mm256_packs_epi32(vr0, vr1), d0), s0); + const __m256i err0 = _mm256_madd_epi16(e0, e0); + sum32 = _mm256_add_epi32(sum32, err0); + } + for (k = j; k < width; ++k) { + const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS); + int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u); + const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + flt0 += flt0_stride; + flt1 += flt1_stride; + const __m256i sum64_0 = + _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32)); + const __m256i sum64_1 = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1)); + sum64 = _mm256_add_epi64(sum64, sum64_0); + sum64 = _mm256_add_epi64(sum64, sum64_1); + } + } else if (params->r[0] > 0) { + __m256i xq_coeff = pair_set_epi16(xq[0], -(xq[0] << SGRPROJ_RST_BITS)); + for (i = 0; i < height; ++i) { + __m256i sum32 = _mm256_setzero_si256(); + for (j = 0; j <= width - 16; j += 16) { + const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j)); + const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j)); + const __m256i flt0_16b = _mm256_permute4x64_epi64( + _mm256_packs_epi32(yy_loadu_256(flt0 + j), + yy_loadu_256(flt0 + j + 8)), + 0xd8); + const __m256i v0 = + _mm256_madd_epi16(xq_coeff, _mm256_unpacklo_epi16(flt0_16b, d0)); + const __m256i v1 = + _mm256_madd_epi16(xq_coeff, _mm256_unpackhi_epi16(flt0_16b, d0)); + const __m256i vr0 = + _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift); + const __m256i vr1 = + _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift); + const __m256i e0 = _mm256_sub_epi16( + _mm256_add_epi16(_mm256_packs_epi32(vr0, vr1), d0), s0); + const __m256i err0 = _mm256_madd_epi16(e0, e0); + sum32 = _mm256_add_epi32(sum32, err0); + } + for (k = j; k < width; ++k) { + const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS); + int32_t v = xq[0] * (flt0[k] - u); + const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + flt0 += flt0_stride; + const __m256i sum64_0 = + _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32)); + const __m256i sum64_1 = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1)); + sum64 = _mm256_add_epi64(sum64, sum64_0); + sum64 = _mm256_add_epi64(sum64, sum64_1); + } + } else if (params->r[1] > 0) { + __m256i xq_coeff = pair_set_epi16(xq[1], -(xq[1] << SGRPROJ_RST_BITS)); + for (i = 0; i < height; ++i) { + __m256i sum32 = _mm256_setzero_si256(); + for (j = 0; j <= width - 16; j += 16) { + const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j)); + const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j)); + const __m256i flt1_16b = _mm256_permute4x64_epi64( + _mm256_packs_epi32(yy_loadu_256(flt1 + j), + yy_loadu_256(flt1 + j + 8)), + 0xd8); + const __m256i v0 = + _mm256_madd_epi16(xq_coeff, _mm256_unpacklo_epi16(flt1_16b, d0)); + const __m256i v1 = + _mm256_madd_epi16(xq_coeff, _mm256_unpackhi_epi16(flt1_16b, d0)); + const __m256i vr0 = + _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift); + const __m256i vr1 = + _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift); + const __m256i e0 = _mm256_sub_epi16( + _mm256_add_epi16(_mm256_packs_epi32(vr0, vr1), d0), s0); + const __m256i err0 = _mm256_madd_epi16(e0, e0); + sum32 = _mm256_add_epi32(sum32, err0); + } + for (k = j; k < width; ++k) { + const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS); + int32_t v = xq[1] * (flt1[k] - u); + const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + flt1 += flt1_stride; + const __m256i sum64_0 = + _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32)); + const __m256i sum64_1 = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1)); + sum64 = _mm256_add_epi64(sum64, sum64_0); + sum64 = _mm256_add_epi64(sum64, sum64_1); + } + } else { + __m256i sum32 = _mm256_setzero_si256(); + for (i = 0; i < height; ++i) { + for (j = 0; j <= width - 16; j += 16) { + const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j)); + const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j)); + const __m256i diff0 = _mm256_sub_epi16(d0, s0); + const __m256i err0 = _mm256_madd_epi16(diff0, diff0); + sum32 = _mm256_add_epi32(sum32, err0); + } + for (k = j; k < width; ++k) { + const int32_t e = (int32_t)(dat[k]) - src[k]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + } + const __m256i sum64_0 = + _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32)); + const __m256i sum64_1 = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1)); + sum64 = _mm256_add_epi64(sum64_0, sum64_1); + } + int64_t sum[4]; + yy_storeu_256(sum, sum64); + err += sum[0] + sum[1] + sum[2] + sum[3]; + return err; +}
diff --git a/av1/encoder/x86/pickrst_sse4.c b/av1/encoder/x86/pickrst_sse4.c index 4f6cde0..04e4d1a 100644 --- a/av1/encoder/x86/pickrst_sse4.c +++ b/av1/encoder/x86/pickrst_sse4.c
@@ -230,3 +230,160 @@ dgd_stride, src_stride, M, H); } } + +static INLINE __m128i pair_set_epi16(uint16_t a, uint16_t b) { + return _mm_set1_epi32((int32_t)(((uint16_t)(a)) | (((uint32_t)(b)) << 16))); +} + +int64_t av1_lowbd_pixel_proj_error_sse4_1( + const uint8_t *src8, int width, int height, int src_stride, + const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, + int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) { + int i, j, k; + const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS; + const __m128i rounding = _mm_set1_epi32(1 << (shift - 1)); + __m128i sum64 = _mm_setzero_si128(); + const uint8_t *src = src8; + const uint8_t *dat = dat8; + int64_t err = 0; + if (params->r[0] > 0 && params->r[1] > 0) { + __m128i xq_coeff = pair_set_epi16(xq[0], xq[1]); + for (i = 0; i < height; ++i) { + __m128i sum32 = _mm_setzero_si128(); + for (j = 0; j < width - 8; j += 8) { + const __m128i d0 = _mm_cvtepu8_epi16(xx_loadl_64(dat + j)); + const __m128i s0 = _mm_cvtepu8_epi16(xx_loadl_64(src + j)); + const __m128i flt0_16b = + _mm_packs_epi32(xx_loadu_128(flt0 + j), xx_loadu_128(flt0 + j + 4)); + const __m128i flt1_16b = + _mm_packs_epi32(xx_loadu_128(flt1 + j), xx_loadu_128(flt1 + j + 4)); + const __m128i u0 = _mm_slli_epi16(d0, SGRPROJ_RST_BITS); + const __m128i flt0_0_sub_u = _mm_sub_epi16(flt0_16b, u0); + const __m128i flt1_0_sub_u = _mm_sub_epi16(flt1_16b, u0); + const __m128i v0 = _mm_madd_epi16( + xq_coeff, _mm_unpacklo_epi16(flt0_0_sub_u, flt1_0_sub_u)); + const __m128i v1 = _mm_madd_epi16( + xq_coeff, _mm_unpackhi_epi16(flt0_0_sub_u, flt1_0_sub_u)); + const __m128i vr0 = _mm_srai_epi32(_mm_add_epi32(v0, rounding), shift); + const __m128i vr1 = _mm_srai_epi32(_mm_add_epi32(v1, rounding), shift); + const __m128i e0 = + _mm_sub_epi16(_mm_add_epi16(_mm_packs_epi32(vr0, vr1), d0), s0); + const __m128i err0 = _mm_madd_epi16(e0, e0); + sum32 = _mm_add_epi32(sum32, err0); + } + for (k = j; k < width; ++k) { + const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS); + int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u); + const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + flt0 += flt0_stride; + flt1 += flt1_stride; + const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32); + const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8)); + sum64 = _mm_add_epi64(sum64, sum64_0); + sum64 = _mm_add_epi64(sum64, sum64_1); + } + } else if (params->r[0] > 0) { + __m128i xq_coeff = pair_set_epi16(xq[0], -(xq[0] << SGRPROJ_RST_BITS)); + for (i = 0; i < height; ++i) { + __m128i sum32 = _mm_setzero_si128(); + for (j = 0; j < width - 8; j += 8) { + const __m128i d0 = _mm_cvtepu8_epi16(xx_loadl_64(dat + j)); + const __m128i s0 = _mm_cvtepu8_epi16(xx_loadl_64(src + j)); + const __m128i flt0_16b = + _mm_packs_epi32(xx_loadu_128(flt0 + j), xx_loadu_128(flt0 + j + 4)); + const __m128i v0 = + _mm_madd_epi16(xq_coeff, _mm_unpacklo_epi16(flt0_16b, d0)); + const __m128i v1 = + _mm_madd_epi16(xq_coeff, _mm_unpackhi_epi16(flt0_16b, d0)); + const __m128i vr0 = _mm_srai_epi32(_mm_add_epi32(v0, rounding), shift); + const __m128i vr1 = _mm_srai_epi32(_mm_add_epi32(v1, rounding), shift); + const __m128i e0 = + _mm_sub_epi16(_mm_add_epi16(_mm_packs_epi32(vr0, vr1), d0), s0); + const __m128i err0 = _mm_madd_epi16(e0, e0); + sum32 = _mm_add_epi32(sum32, err0); + } + for (k = j; k < width; ++k) { + const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS); + int32_t v = xq[0] * (flt0[k] - u); + const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + flt0 += flt0_stride; + const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32); + const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8)); + sum64 = _mm_add_epi64(sum64, sum64_0); + sum64 = _mm_add_epi64(sum64, sum64_1); + } + } else if (params->r[1] > 0) { + __m128i xq_coeff = pair_set_epi16(xq[1], -(xq[1] << SGRPROJ_RST_BITS)); + for (i = 0; i < height; ++i) { + __m128i sum32 = _mm_setzero_si128(); + for (j = 0; j < width - 8; j += 8) { + const __m128i d0 = _mm_cvtepu8_epi16(xx_loadl_64(dat + j)); + const __m128i s0 = _mm_cvtepu8_epi16(xx_loadl_64(src + j)); + const __m128i flt1_16b = + _mm_packs_epi32(xx_loadu_128(flt1 + j), xx_loadu_128(flt1 + j + 4)); + const __m128i v0 = + _mm_madd_epi16(xq_coeff, _mm_unpacklo_epi16(flt1_16b, d0)); + const __m128i v1 = + _mm_madd_epi16(xq_coeff, _mm_unpackhi_epi16(flt1_16b, d0)); + const __m128i vr0 = _mm_srai_epi32(_mm_add_epi32(v0, rounding), shift); + const __m128i vr1 = _mm_srai_epi32(_mm_add_epi32(v1, rounding), shift); + const __m128i e0 = + _mm_sub_epi16(_mm_add_epi16(_mm_packs_epi32(vr0, vr1), d0), s0); + const __m128i err0 = _mm_madd_epi16(e0, e0); + sum32 = _mm_add_epi32(sum32, err0); + } + for (k = j; k < width; ++k) { + const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS); + int32_t v = xq[1] * (flt1[k] - u); + const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + flt1 += flt1_stride; + const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32); + const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8)); + sum64 = _mm_add_epi64(sum64, sum64_0); + sum64 = _mm_add_epi64(sum64, sum64_1); + } + } else { + __m128i sum32 = _mm_setzero_si128(); + for (i = 0; i < height; ++i) { + for (j = 0; j < width - 16; j += 16) { + const __m128i d = xx_loadu_128(dat + j); + const __m128i s = xx_loadu_128(src + j); + const __m128i d0 = _mm_cvtepu8_epi16(d); + const __m128i d1 = _mm_cvtepu8_epi16(_mm_srli_si128(d, 8)); + const __m128i s0 = _mm_cvtepu8_epi16(s); + const __m128i s1 = _mm_cvtepu8_epi16(_mm_srli_si128(s, 8)); + const __m128i diff0 = _mm_sub_epi16(d0, s0); + const __m128i diff1 = _mm_sub_epi16(d1, s1); + const __m128i err0 = _mm_madd_epi16(diff0, diff0); + const __m128i err1 = _mm_madd_epi16(diff1, diff1); + sum32 = _mm_add_epi32(sum32, err0); + sum32 = _mm_add_epi32(sum32, err1); + } + for (k = j; k < width; ++k) { + const int32_t e = (int32_t)(dat[k]) - src[k]; + err += e * e; + } + dat += dat_stride; + src += src_stride; + } + const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32); + const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8)); + sum64 = _mm_add_epi64(sum64_0, sum64_1); + } + int64_t sum[2]; + xx_storeu_128(sum, sum64); + err += sum[0] + sum[1]; + return err; +}
diff --git a/test/pickrst_test.cc b/test/pickrst_test.cc new file mode 100644 index 0000000..040e8e8 --- /dev/null +++ b/test/pickrst_test.cc
@@ -0,0 +1,187 @@ +/* + * Copyright (c) 2018, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include "third_party/googletest/src/googletest/include/gtest/gtest.h" + +#include "test/function_equivalence_test.h" +#include "test/register_state_check.h" + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "av1/encoder/pickrst.h" +using libaom_test::FunctionEquivalenceTest; + +#define MAX_DATA_BLOCK 384 + +namespace { +static const int kIterations = 100; + +typedef int64_t (*lowbd_pixel_proj_error_func)( + const uint8_t *src8, int width, int height, int src_stride, + const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, + int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params); + +typedef libaom_test::FuncParam<lowbd_pixel_proj_error_func> TestFuncs; + +//////////////////////////////////////////////////////////////////////////////// +// 8 bit +//////////////////////////////////////////////////////////////////////////////// + +typedef ::testing::tuple<const lowbd_pixel_proj_error_func> + PixelProjErrorTestParam; + +class PixelProjErrorTest + : public ::testing::TestWithParam<PixelProjErrorTestParam> { + public: + virtual void SetUp() { + target_func_ = GET_PARAM(0); + src_ = (uint8_t *)(aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * + sizeof(uint8_t))); + dgd_ = (uint8_t *)(aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * + sizeof(uint8_t))); + flt0_ = (int32_t *)(aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * + sizeof(int32_t))); + flt1_ = (int32_t *)(aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * + sizeof(int32_t))); + } + virtual void TearDown() { + aom_free(src_); + aom_free(dgd_); + aom_free(flt0_); + aom_free(flt1_); + } + void runPixelProjErrorTest(int32_t run_times); + void runPixelProjErrorTest_ExtremeValues(); + + private: + lowbd_pixel_proj_error_func target_func_; + ACMRandom rng_; + uint8_t *src_; + uint8_t *dgd_; + int32_t *flt0_; + int32_t *flt1_; +}; + +void PixelProjErrorTest::runPixelProjErrorTest(int32_t run_times) { + int h_end = run_times != 1 ? 128 : (rng_.Rand16() % MAX_DATA_BLOCK) + 1; + int v_end = run_times != 1 ? 128 : (rng_.Rand16() % MAX_DATA_BLOCK) + 1; + const int dgd_stride = MAX_DATA_BLOCK; + const int src_stride = MAX_DATA_BLOCK; + const int flt0_stride = MAX_DATA_BLOCK; + const int flt1_stride = MAX_DATA_BLOCK; + sgr_params_type params; + int xq[2]; + const int iters = run_times == 1 ? kIterations : 4; + for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) { + int64_t err_ref = 0, err_test = 1; + for (int i = 0; i < MAX_DATA_BLOCK * MAX_DATA_BLOCK; ++i) { + dgd_[i] = rng_.Rand8(); + src_[i] = rng_.Rand8(); + flt0_[i] = rng_.Rand15Signed(); + flt1_[i] = rng_.Rand15Signed(); + } + xq[0] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS); + xq[1] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS); + params.r[0] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter % 2); + params.r[1] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter / 2); + params.s[0] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter % 2); + params.s[1] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter / 2); + uint8_t *dgd = dgd_; + uint8_t *src = src_; + + aom_usec_timer timer; + aom_usec_timer_start(&timer); + for (int i = 0; i < run_times; ++i) { + err_ref = av1_lowbd_pixel_proj_error_c(src, h_end, v_end, src_stride, dgd, + dgd_stride, flt0_, flt0_stride, + flt1_, flt1_stride, xq, ¶ms); + } + aom_usec_timer_mark(&timer); + const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer)); + aom_usec_timer_start(&timer); + for (int i = 0; i < run_times; ++i) { + err_test = + target_func_(src, h_end, v_end, src_stride, dgd, dgd_stride, flt0_, + flt0_stride, flt1_, flt1_stride, xq, ¶ms); + } + aom_usec_timer_mark(&timer); + const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer)); + if (run_times > 10) { + printf("r0 %d r1 %d %3dx%-3d:%7.2f/%7.2fns (%3.2f)\n", params.r[0], + params.r[1], h_end, v_end, time1, time2, time1 / time2); + } + ASSERT_EQ(err_ref, err_test); + } +} + +void PixelProjErrorTest::runPixelProjErrorTest_ExtremeValues() { + const int h_start = 0; + int h_end = 192; + const int v_start = 0; + int v_end = 192; + const int dgd_stride = MAX_DATA_BLOCK; + const int src_stride = MAX_DATA_BLOCK; + const int flt0_stride = MAX_DATA_BLOCK; + const int flt1_stride = MAX_DATA_BLOCK; + sgr_params_type params; + int xq[2]; + const int iters = kIterations; + for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) { + int64_t err_ref = 0, err_test = 1; + for (int i = 0; i < MAX_DATA_BLOCK * MAX_DATA_BLOCK; ++i) { + dgd_[i] = 0; + src_[i] = 255; + flt0_[i] = rng_.Rand15Signed(); + flt1_[i] = rng_.Rand15Signed(); + } + xq[0] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS); + xq[1] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS); + params.r[0] = rng_.Rand8() % MAX_RADIUS; + params.r[1] = rng_.Rand8() % MAX_RADIUS; + params.s[0] = rng_.Rand8() % MAX_RADIUS; + params.s[1] = rng_.Rand8() % MAX_RADIUS; + uint8_t *dgd = dgd_; + uint8_t *src = src_; + + err_ref = av1_lowbd_pixel_proj_error_c( + src, h_end - h_start, v_end - v_start, src_stride, dgd, dgd_stride, + flt0_, flt0_stride, flt1_, flt1_stride, xq, ¶ms); + + err_test = target_func_(src, h_end - h_start, v_end - v_start, src_stride, + dgd, dgd_stride, flt0_, flt0_stride, flt1_, + flt1_stride, xq, ¶ms); + + ASSERT_EQ(err_ref, err_test); + } +} + +TEST_P(PixelProjErrorTest, RandomValues) { runPixelProjErrorTest(1); } + +TEST_P(PixelProjErrorTest, ExtremeValues) { + runPixelProjErrorTest_ExtremeValues(); +} + +TEST_P(PixelProjErrorTest, DISABLED_Speed) { runPixelProjErrorTest(200000); } + +#if HAVE_SSE4_1 +INSTANTIATE_TEST_CASE_P(SSE4_1, PixelProjErrorTest, + ::testing::Values(av1_lowbd_pixel_proj_error_sse4_1)); +#endif // HAVE_SSE4_1 + +#if HAVE_AVX2 + +INSTANTIATE_TEST_CASE_P(AVX2, PixelProjErrorTest, + ::testing::Values(av1_lowbd_pixel_proj_error_avx2)); +#endif // HAVE_AVX2 + +} // namespace
diff --git a/test/test.cmake b/test/test.cmake index a6b7eaa..1aa08d3 100644 --- a/test/test.cmake +++ b/test/test.cmake
@@ -72,6 +72,7 @@ "${AOM_ROOT}/test/resize_test.cc" "${AOM_ROOT}/test/scalability_test.cc" "${AOM_ROOT}/test/wiener_test.cc" + "${AOM_ROOT}/test/pickrst_test.cc" "${AOM_ROOT}/test/y4m_test.cc" "${AOM_ROOT}/test/y4m_video_source.h" "${AOM_ROOT}/test/yuv_video_source.h")