convolve_y/jnt_convolve_y support for round_1 > 0 The case when round_0 + round_1 > FILTER_BITS used to result in negative shifts. Change-Id: I25ec3fb187432f4f34cffab5a01158621ecf3503
diff --git a/av1/common/convolve.c b/av1/common/convolve.c index 20a02cd..7645653 100644 --- a/av1/common/convolve.c +++ b/av1/common/convolve.c
@@ -441,7 +441,7 @@ CONV_BUF_TYPE *dst = conv_params->dst; int dst_stride = conv_params->dst_stride; const int fo_vert = filter_params_y->taps / 2 - 1; - const int bits = FILTER_BITS - conv_params->round_0 - conv_params->round_1; + const int bits = FILTER_BITS - conv_params->round_0; (void)filter_params_x; (void)subpel_x_q4; (void)dst0; @@ -456,12 +456,8 @@ for (int k = 0; k < filter_params_y->taps; ++k) { res += y_filter[k] * src[(y - fo_vert + k) * src_stride + x]; } -#if CONFIG_LOWPRECISION_BLEND - if (bits < 0) - res = ROUND_POWER_OF_TWO(res, bits); - else -#endif // CONFIG_LOWPRECISION_BLEND - res *= (1 << bits); + res *= (1 << bits); + res = ROUND_POWER_OF_TWO(res, conv_params->round_1); if (conv_params->do_average) dst[y * dst_stride + x] += res; else @@ -730,7 +726,7 @@ CONV_BUF_TYPE *dst = conv_params->dst; int dst_stride = conv_params->dst_stride; const int fo_vert = filter_params_y->taps / 2 - 1; - const int bits = FILTER_BITS - conv_params->round_0 - conv_params->round_1; + const int bits = FILTER_BITS - conv_params->round_0; (void)filter_params_x; (void)subpel_x_q4; (void)dst0; @@ -745,12 +741,8 @@ for (int k = 0; k < filter_params_y->taps; ++k) { res += y_filter[k] * src[(y - fo_vert + k) * src_stride + x]; } -#if CONFIG_LOWPRECISION_BLEND - if (bits < 0) - res = ROUND_POWER_OF_TWO(res, bits); - else -#endif // CONFIG_LOWPRECISION_BLEND - res *= (1 << bits); + res *= (1 << bits); + res = ROUND_POWER_OF_TWO(res, conv_params->round_1); if (conv_params->use_jnt_comp_avg) { if (conv_params->do_average) { dst[y * dst_stride + x] += res * conv_params->bck_offset;
diff --git a/av1/common/x86/convolve_avx2.c b/av1/common/x86/convolve_avx2.c index 7fd76f8..b450ff2 100644 --- a/av1/common/x86/convolve_avx2.c +++ b/av1/common/x86/convolve_avx2.c
@@ -350,8 +350,12 @@ int i, j; const int fo_vert = filter_params_y->taps / 2 - 1; const uint8_t *const src_ptr = src - fo_vert * src_stride; - const int bits = - FILTER_BITS - conv_params->round_0 - (conv_params->round_1 - 1); + // +1 to compensate for dividing the filter coeffs by 2 + const int left_shift = FILTER_BITS - conv_params->round_0 + 1; + const __m256i round_const = + _mm256_set1_epi32((1 << conv_params->round_1) >> 1); + const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1); + const __m256i avg_mask = _mm256_set1_epi32(conv_params->do_average ? -1 : 0); __m256i coeffs[4], s[8]; @@ -439,16 +443,22 @@ const __m256i res_lo_0_32b = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(res_lo)); - const __m256i res_lo_0_shift = _mm256_slli_epi32(res_lo_0_32b, bits); + const __m256i res_lo_0_shift = + _mm256_slli_epi32(res_lo_0_32b, left_shift); + const __m256i res_lo_0_round = _mm256_sra_epi32( + _mm256_add_epi32(res_lo_0_shift, round_const), round_shift); // Accumulate values into the destination buffer - add_store_aligned(&dst[i * dst_stride + j], &res_lo_0_shift, &avg_mask); + add_store_aligned(&dst[i * dst_stride + j], &res_lo_0_round, &avg_mask); const __m256i res_lo_1_32b = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(res_lo, 1)); - const __m256i res_lo_1_shift = _mm256_slli_epi32(res_lo_1_32b, bits); + const __m256i res_lo_1_shift = + _mm256_slli_epi32(res_lo_1_32b, left_shift); + const __m256i res_lo_1_round = _mm256_sra_epi32( + _mm256_add_epi32(res_lo_1_shift, round_const), round_shift); - add_store_aligned(&dst[i * dst_stride + j + dst_stride], &res_lo_1_shift, + add_store_aligned(&dst[i * dst_stride + j + dst_stride], &res_lo_1_round, &avg_mask); if (w - j > 8) { @@ -456,17 +466,23 @@ const __m256i res_hi_0_32b = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(res_hi)); - const __m256i res_hi_0_shift = _mm256_slli_epi32(res_hi_0_32b, bits); + const __m256i res_hi_0_shift = + _mm256_slli_epi32(res_hi_0_32b, left_shift); + const __m256i res_hi_0_round = _mm256_sra_epi32( + _mm256_add_epi32(res_hi_0_shift, round_const), round_shift); - add_store_aligned(&dst[i * dst_stride + j + 8], &res_hi_0_shift, + add_store_aligned(&dst[i * dst_stride + j + 8], &res_hi_0_round, &avg_mask); const __m256i res_hi_1_32b = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(res_hi, 1)); - const __m256i res_hi_1_shift = _mm256_slli_epi32(res_hi_1_32b, bits); + const __m256i res_hi_1_shift = + _mm256_slli_epi32(res_hi_1_32b, left_shift); + const __m256i res_hi_1_round = _mm256_sra_epi32( + _mm256_add_epi32(res_hi_1_shift, round_const), round_shift); add_store_aligned(&dst[i * dst_stride + j + 8 + dst_stride], - &res_hi_1_shift, &avg_mask); + &res_hi_1_round, &avg_mask); } s[0] = s[1]; s[1] = s[2];
diff --git a/av1/common/x86/convolve_sse2.c b/av1/common/x86/convolve_sse2.c index d384b70..f8081d2 100644 --- a/av1/common/x86/convolve_sse2.c +++ b/av1/common/x86/convolve_sse2.c
@@ -93,8 +93,10 @@ const int dst_stride = conv_params->dst_stride; const int fo_vert = filter_params_y->taps / 2 - 1; const uint8_t *src_ptr = src - fo_vert * src_stride; - const int bits = FILTER_BITS - conv_params->round_0 - conv_params->round_1; + const int bits = FILTER_BITS - conv_params->round_0; const __m128i left_shift = _mm_cvtsi32_si128(bits); + const __m128i round_const = _mm_set1_epi32((1 << conv_params->round_1) >> 1); + const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1); const __m128i avg_mask = _mm_set1_epi32(conv_params->do_average ? -1 : 0); __m128i coeffs[4]; @@ -135,12 +137,16 @@ res = convolve_lo_y(s + 0, coeffs); res_shift = _mm_sll_epi32(res, left_shift); + res_shift = + _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift); add_store(dst, &res_shift, &avg_mask); src_ptr += src_stride; dst += dst_stride; res = convolve_lo_y(s + 1, coeffs); res_shift = _mm_sll_epi32(res, left_shift); + res_shift = + _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift); add_store(dst, &res_shift, &avg_mask); src_ptr += src_stride; dst += dst_stride; @@ -192,6 +198,10 @@ res_hi = convolve_hi_y(s, coeffs); // Filter high index pixels res_lo_shift = _mm_sll_epi32(res_lo, left_shift); res_hi_shift = _mm_sll_epi32(res_hi, left_shift); + res_lo_shift = _mm_sra_epi32(_mm_add_epi32(res_lo_shift, round_const), + round_shift); + res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const), + round_shift); add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask); add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask); i++; @@ -200,6 +210,10 @@ res_hi = convolve_hi_y(s + 1, coeffs); // Filter high index pixels res_lo_shift = _mm_sll_epi32(res_lo, left_shift); res_hi_shift = _mm_sll_epi32(res_hi, left_shift); + res_lo_shift = _mm_sra_epi32(_mm_add_epi32(res_lo_shift, round_const), + round_shift); + res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const), + round_shift); add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask); add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask); i++;
diff --git a/av1/common/x86/jnt_convolve_sse4.c b/av1/common/x86/jnt_convolve_sse4.c index 42404d4..9de2744 100644 --- a/av1/common/x86/jnt_convolve_sse4.c +++ b/av1/common/x86/jnt_convolve_sse4.c
@@ -107,12 +107,14 @@ const int dst_stride = conv_params->dst_stride; const int fo_vert = filter_params_y->taps / 2 - 1; const uint8_t *src_ptr = src - fo_vert * src_stride; - const int bits = FILTER_BITS - conv_params->round_0 - conv_params->round_1; + const int bits = FILTER_BITS - conv_params->round_0; const __m128i left_shift = _mm_cvtsi32_si128(bits); const __m128i avg_mask = _mm_set1_epi32(conv_params->do_average ? -1 : 0); const __m128i wt0 = _mm_set1_epi32(conv_params->fwd_offset); const __m128i wt1 = _mm_set1_epi32(conv_params->bck_offset); const __m128i wt = conv_params->do_average ? wt1 : wt0; + const __m128i round_const = _mm_set1_epi32((1 << conv_params->round_1) >> 1); + const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1); __m128i coeffs[4]; (void)filter_params_x; @@ -152,6 +154,8 @@ res = convolve_lo_y(s + 0, coeffs); res_shift = _mm_sll_epi32(res, left_shift); + res_shift = + _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift); if (conv_params->use_jnt_comp_avg) mult_add_store(dst, &res_shift, &avg_mask, &wt, conv_params->do_average); @@ -162,6 +166,8 @@ res = convolve_lo_y(s + 1, coeffs); res_shift = _mm_sll_epi32(res, left_shift); + res_shift = + _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift); if (conv_params->use_jnt_comp_avg) mult_add_store(dst, &res_shift, &avg_mask, &wt, conv_params->do_average); @@ -217,6 +223,10 @@ res_hi = convolve_hi_y(s, coeffs); // Filter high index pixels res_lo_shift = _mm_sll_epi32(res_lo, left_shift); res_hi_shift = _mm_sll_epi32(res_hi, left_shift); + res_lo_shift = _mm_sra_epi32(_mm_add_epi32(res_lo_shift, round_const), + round_shift); + res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const), + round_shift); if (conv_params->use_jnt_comp_avg) { mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask, &wt, conv_params->do_average); @@ -232,6 +242,10 @@ res_hi = convolve_hi_y(s + 1, coeffs); // Filter high index pixels res_lo_shift = _mm_sll_epi32(res_lo, left_shift); res_hi_shift = _mm_sll_epi32(res_hi, left_shift); + res_lo_shift = _mm_sra_epi32(_mm_add_epi32(res_lo_shift, round_const), + round_shift); + res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const), + round_shift); if (conv_params->use_jnt_comp_avg) { mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask, &wt, conv_params->do_average);
diff --git a/test/av1_convolve_2d_test.cc b/test/av1_convolve_2d_test.cc index 04ceb79..50154e9 100644 --- a/test/av1_convolve_2d_test.cc +++ b/test/av1_convolve_2d_test.cc
@@ -82,10 +82,18 @@ av1_convolve_2d_copy_sr_sse2, 0, 0, 1)); INSTANTIATE_TEST_CASE_P( + C_X, AV1Convolve2DSrTest, + libaom_test::AV1Convolve2D::BuildParams(av1_convolve_x_sr_c, 1, 0, 0)); + +INSTANTIATE_TEST_CASE_P( SSE2_X, AV1Convolve2DSrTest, libaom_test::AV1Convolve2D::BuildParams(av1_convolve_x_sr_sse2, 1, 0, 0)); INSTANTIATE_TEST_CASE_P( + C_Y, AV1Convolve2DSrTest, + libaom_test::AV1Convolve2D::BuildParams(av1_convolve_y_sr_c, 0, 1, 0)); + +INSTANTIATE_TEST_CASE_P( SSE2_Y, AV1Convolve2DSrTest, libaom_test::AV1Convolve2D::BuildParams(av1_convolve_y_sr_sse2, 0, 1, 0));