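Support round_1 > 0 in convolve_y/jnt_convolve_y
The shift used to be computed as FILTER_BITS - round_0 - round_1, which
became negative when round_0 + round_1 > FILTER_BITS. The result is now
left-shifted by FILTER_BITS - round_0 and then rounded and right-shifted
by round_1, so neither shift amount can go negative because of round_1.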
Change-Id: I25ec3fb187432f4f34cffab5a01158621ecf3503
diff --git a/av1/common/convolve.c b/av1/common/convolve.c
index 20a02cd..7645653 100644
--- a/av1/common/convolve.c
+++ b/av1/common/convolve.c
@@ -441,7 +441,7 @@
CONV_BUF_TYPE *dst = conv_params->dst;
int dst_stride = conv_params->dst_stride;
const int fo_vert = filter_params_y->taps / 2 - 1;
- const int bits = FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+ const int bits = FILTER_BITS - conv_params->round_0;
(void)filter_params_x;
(void)subpel_x_q4;
(void)dst0;
@@ -456,12 +456,8 @@
for (int k = 0; k < filter_params_y->taps; ++k) {
res += y_filter[k] * src[(y - fo_vert + k) * src_stride + x];
}
-#if CONFIG_LOWPRECISION_BLEND
- if (bits < 0)
- res = ROUND_POWER_OF_TWO(res, bits);
- else
-#endif // CONFIG_LOWPRECISION_BLEND
- res *= (1 << bits);
+ res *= (1 << bits);
+ res = ROUND_POWER_OF_TWO(res, conv_params->round_1);
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
@@ -730,7 +726,7 @@
CONV_BUF_TYPE *dst = conv_params->dst;
int dst_stride = conv_params->dst_stride;
const int fo_vert = filter_params_y->taps / 2 - 1;
- const int bits = FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+ const int bits = FILTER_BITS - conv_params->round_0;
(void)filter_params_x;
(void)subpel_x_q4;
(void)dst0;
@@ -745,12 +741,8 @@
for (int k = 0; k < filter_params_y->taps; ++k) {
res += y_filter[k] * src[(y - fo_vert + k) * src_stride + x];
}
-#if CONFIG_LOWPRECISION_BLEND
- if (bits < 0)
- res = ROUND_POWER_OF_TWO(res, bits);
- else
-#endif // CONFIG_LOWPRECISION_BLEND
- res *= (1 << bits);
+ res *= (1 << bits);
+ res = ROUND_POWER_OF_TWO(res, conv_params->round_1);
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
diff --git a/av1/common/x86/convolve_avx2.c b/av1/common/x86/convolve_avx2.c
index 7fd76f8..b450ff2 100644
--- a/av1/common/x86/convolve_avx2.c
+++ b/av1/common/x86/convolve_avx2.c
@@ -350,8 +350,12 @@
int i, j;
const int fo_vert = filter_params_y->taps / 2 - 1;
const uint8_t *const src_ptr = src - fo_vert * src_stride;
- const int bits =
- FILTER_BITS - conv_params->round_0 - (conv_params->round_1 - 1);
+ // +1 to compensate for dividing the filter coeffs by 2
+ const int left_shift = FILTER_BITS - conv_params->round_0 + 1;
+ const __m256i round_const =
+ _mm256_set1_epi32((1 << conv_params->round_1) >> 1);
+ const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
+
const __m256i avg_mask = _mm256_set1_epi32(conv_params->do_average ? -1 : 0);
__m256i coeffs[4], s[8];
@@ -439,16 +443,22 @@
const __m256i res_lo_0_32b =
_mm256_cvtepi16_epi32(_mm256_castsi256_si128(res_lo));
- const __m256i res_lo_0_shift = _mm256_slli_epi32(res_lo_0_32b, bits);
+ const __m256i res_lo_0_shift =
+ _mm256_slli_epi32(res_lo_0_32b, left_shift);
+ const __m256i res_lo_0_round = _mm256_sra_epi32(
+ _mm256_add_epi32(res_lo_0_shift, round_const), round_shift);
// Accumulate values into the destination buffer
- add_store_aligned(&dst[i * dst_stride + j], &res_lo_0_shift, &avg_mask);
+ add_store_aligned(&dst[i * dst_stride + j], &res_lo_0_round, &avg_mask);
const __m256i res_lo_1_32b =
_mm256_cvtepi16_epi32(_mm256_extracti128_si256(res_lo, 1));
- const __m256i res_lo_1_shift = _mm256_slli_epi32(res_lo_1_32b, bits);
+ const __m256i res_lo_1_shift =
+ _mm256_slli_epi32(res_lo_1_32b, left_shift);
+ const __m256i res_lo_1_round = _mm256_sra_epi32(
+ _mm256_add_epi32(res_lo_1_shift, round_const), round_shift);
- add_store_aligned(&dst[i * dst_stride + j + dst_stride], &res_lo_1_shift,
+ add_store_aligned(&dst[i * dst_stride + j + dst_stride], &res_lo_1_round,
&avg_mask);
if (w - j > 8) {
@@ -456,17 +466,23 @@
const __m256i res_hi_0_32b =
_mm256_cvtepi16_epi32(_mm256_castsi256_si128(res_hi));
- const __m256i res_hi_0_shift = _mm256_slli_epi32(res_hi_0_32b, bits);
+ const __m256i res_hi_0_shift =
+ _mm256_slli_epi32(res_hi_0_32b, left_shift);
+ const __m256i res_hi_0_round = _mm256_sra_epi32(
+ _mm256_add_epi32(res_hi_0_shift, round_const), round_shift);
- add_store_aligned(&dst[i * dst_stride + j + 8], &res_hi_0_shift,
+ add_store_aligned(&dst[i * dst_stride + j + 8], &res_hi_0_round,
&avg_mask);
const __m256i res_hi_1_32b =
_mm256_cvtepi16_epi32(_mm256_extracti128_si256(res_hi, 1));
- const __m256i res_hi_1_shift = _mm256_slli_epi32(res_hi_1_32b, bits);
+ const __m256i res_hi_1_shift =
+ _mm256_slli_epi32(res_hi_1_32b, left_shift);
+ const __m256i res_hi_1_round = _mm256_sra_epi32(
+ _mm256_add_epi32(res_hi_1_shift, round_const), round_shift);
add_store_aligned(&dst[i * dst_stride + j + 8 + dst_stride],
- &res_hi_1_shift, &avg_mask);
+ &res_hi_1_round, &avg_mask);
}
s[0] = s[1];
s[1] = s[2];
diff --git a/av1/common/x86/convolve_sse2.c b/av1/common/x86/convolve_sse2.c
index d384b70..f8081d2 100644
--- a/av1/common/x86/convolve_sse2.c
+++ b/av1/common/x86/convolve_sse2.c
@@ -93,8 +93,10 @@
const int dst_stride = conv_params->dst_stride;
const int fo_vert = filter_params_y->taps / 2 - 1;
const uint8_t *src_ptr = src - fo_vert * src_stride;
- const int bits = FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+ const int bits = FILTER_BITS - conv_params->round_0;
const __m128i left_shift = _mm_cvtsi32_si128(bits);
+ const __m128i round_const = _mm_set1_epi32((1 << conv_params->round_1) >> 1);
+ const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
const __m128i avg_mask = _mm_set1_epi32(conv_params->do_average ? -1 : 0);
__m128i coeffs[4];
@@ -135,12 +137,16 @@
res = convolve_lo_y(s + 0, coeffs);
res_shift = _mm_sll_epi32(res, left_shift);
+ res_shift =
+ _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
add_store(dst, &res_shift, &avg_mask);
src_ptr += src_stride;
dst += dst_stride;
res = convolve_lo_y(s + 1, coeffs);
res_shift = _mm_sll_epi32(res, left_shift);
+ res_shift =
+ _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
add_store(dst, &res_shift, &avg_mask);
src_ptr += src_stride;
dst += dst_stride;
@@ -192,6 +198,10 @@
res_hi = convolve_hi_y(s, coeffs); // Filter high index pixels
res_lo_shift = _mm_sll_epi32(res_lo, left_shift);
res_hi_shift = _mm_sll_epi32(res_hi, left_shift);
+ res_lo_shift = _mm_sra_epi32(_mm_add_epi32(res_lo_shift, round_const),
+ round_shift);
+ res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
+ round_shift);
add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
i++;
@@ -200,6 +210,10 @@
res_hi = convolve_hi_y(s + 1, coeffs); // Filter high index pixels
res_lo_shift = _mm_sll_epi32(res_lo, left_shift);
res_hi_shift = _mm_sll_epi32(res_hi, left_shift);
+ res_lo_shift = _mm_sra_epi32(_mm_add_epi32(res_lo_shift, round_const),
+ round_shift);
+ res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
+ round_shift);
add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
i++;
diff --git a/av1/common/x86/jnt_convolve_sse4.c b/av1/common/x86/jnt_convolve_sse4.c
index 42404d4..9de2744 100644
--- a/av1/common/x86/jnt_convolve_sse4.c
+++ b/av1/common/x86/jnt_convolve_sse4.c
@@ -107,12 +107,14 @@
const int dst_stride = conv_params->dst_stride;
const int fo_vert = filter_params_y->taps / 2 - 1;
const uint8_t *src_ptr = src - fo_vert * src_stride;
- const int bits = FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+ const int bits = FILTER_BITS - conv_params->round_0;
const __m128i left_shift = _mm_cvtsi32_si128(bits);
const __m128i avg_mask = _mm_set1_epi32(conv_params->do_average ? -1 : 0);
const __m128i wt0 = _mm_set1_epi32(conv_params->fwd_offset);
const __m128i wt1 = _mm_set1_epi32(conv_params->bck_offset);
const __m128i wt = conv_params->do_average ? wt1 : wt0;
+ const __m128i round_const = _mm_set1_epi32((1 << conv_params->round_1) >> 1);
+ const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
__m128i coeffs[4];
(void)filter_params_x;
@@ -152,6 +154,8 @@
res = convolve_lo_y(s + 0, coeffs);
res_shift = _mm_sll_epi32(res, left_shift);
+ res_shift =
+ _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
if (conv_params->use_jnt_comp_avg)
mult_add_store(dst, &res_shift, &avg_mask, &wt,
conv_params->do_average);
@@ -162,6 +166,8 @@
res = convolve_lo_y(s + 1, coeffs);
res_shift = _mm_sll_epi32(res, left_shift);
+ res_shift =
+ _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
if (conv_params->use_jnt_comp_avg)
mult_add_store(dst, &res_shift, &avg_mask, &wt,
conv_params->do_average);
@@ -217,6 +223,10 @@
res_hi = convolve_hi_y(s, coeffs); // Filter high index pixels
res_lo_shift = _mm_sll_epi32(res_lo, left_shift);
res_hi_shift = _mm_sll_epi32(res_hi, left_shift);
+ res_lo_shift = _mm_sra_epi32(_mm_add_epi32(res_lo_shift, round_const),
+ round_shift);
+ res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
+ round_shift);
if (conv_params->use_jnt_comp_avg) {
mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
&wt, conv_params->do_average);
@@ -232,6 +242,10 @@
res_hi = convolve_hi_y(s + 1, coeffs); // Filter high index pixels
res_lo_shift = _mm_sll_epi32(res_lo, left_shift);
res_hi_shift = _mm_sll_epi32(res_hi, left_shift);
+ res_lo_shift = _mm_sra_epi32(_mm_add_epi32(res_lo_shift, round_const),
+ round_shift);
+ res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
+ round_shift);
if (conv_params->use_jnt_comp_avg) {
mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
&wt, conv_params->do_average);
diff --git a/test/av1_convolve_2d_test.cc b/test/av1_convolve_2d_test.cc
index 04ceb79..50154e9 100644
--- a/test/av1_convolve_2d_test.cc
+++ b/test/av1_convolve_2d_test.cc
@@ -82,10 +82,18 @@
av1_convolve_2d_copy_sr_sse2, 0, 0, 1));
INSTANTIATE_TEST_CASE_P(
+ C_X, AV1Convolve2DSrTest,
+ libaom_test::AV1Convolve2D::BuildParams(av1_convolve_x_sr_c, 1, 0, 0));
+
+INSTANTIATE_TEST_CASE_P(
SSE2_X, AV1Convolve2DSrTest,
libaom_test::AV1Convolve2D::BuildParams(av1_convolve_x_sr_sse2, 1, 0, 0));
INSTANTIATE_TEST_CASE_P(
+ C_Y, AV1Convolve2DSrTest,
+ libaom_test::AV1Convolve2D::BuildParams(av1_convolve_y_sr_c, 0, 1, 0));
+
+INSTANTIATE_TEST_CASE_P(
SSE2_Y, AV1Convolve2DSrTest,
libaom_test::AV1Convolve2D::BuildParams(av1_convolve_y_sr_sse2, 0, 1, 0));