[jnt-comp, normative] Avoid double-rounding in prediction
As per the linked bug report, the distance-weighted compound
prediction currently applies two separate rounding operations:
first by 3 bits (DIST_PRECISION_BITS - 1, inside the various
convolve functions), then by 10 bits (after the convolve
functions).
We can improve on this by replacing the first round with a plain
right shift by 3 bits inside the convolve functions; combined
with the later 10-bit round, this is equivalent to doing a single
round by 13 bits at the end.
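To see why the two are equivalent, here is a minimal standalone
sketch (illustrative only, not part of this change; it assumes a
non-negative accumulator and uses a simplified form of the
ROUND_POWER_OF_TWO macro from aom_dsp/aom_dsp_common.h):

    #include <assert.h>
    #include <stdint.h>

    /* Simplified ROUND_POWER_OF_TWO, valid for n >= 1. */
    #define ROUND_POWER_OF_TWO(value, n) \
      (((value) + (1 << ((n)-1))) >> (n))

    int main(void) {
      for (int32_t x = 0; x < (1 << 20); ++x) {
        /* After this change: a plain 3-bit shift inside the
         * convolve function, then the existing 10-bit round. */
        const int32_t shifted_then_rounded =
            ROUND_POWER_OF_TWO(x >> 3, 10);
        /* The single 13-bit round this is equivalent to. */
        const int32_t single_round = ROUND_POWER_OF_TWO(x, 13);
        assert(shifted_then_rounded == single_round);
      }
      return 0;
    }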
Note: In the encoder, joint_motion_search() does things a bit
differently: so that we can try modifying the two "sides" of the
prediction independently, we predict each side as if it were a
single prediction (including rounding), then blend these single
predictions together. This is already an approximation to the
"real" prediction, even in the non-jnt-comp case, so we leave
that code path as-is.
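For reference, a minimal sketch of the shape of that encoder-side
approximation (illustrative only; the function and buffer names
here are assumptions, and the plain (p0 + p1 + 1) >> 1 average
stands in for the fwd_offset/bck_offset distance weights that the
jnt-comp path would use):

    #include <stdint.h>

    /* Blend two independently rounded single predictions into a
     * compound prediction. Predicting each side separately is
     * what lets the encoder re-predict one side at a time. */
    static void blend_single_predictions(const uint8_t *pred0,
                                         const uint8_t *pred1,
                                         uint8_t *comp_pred,
                                         int width, int height,
                                         int stride) {
      for (int y = 0; y < height; ++y) {
        for (int x = 0; x < width; ++x) {
          const int i = y * stride + x;
          /* Rounded average of the two single predictions. */
          comp_pred[i] = (uint8_t)((pred0[i] + pred1[i] + 1) >> 1);
        }
      }
    }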
BUG=aomedia:1289
Change-Id: I9ad1fbcb3e12db2b5fc3c82b407f0fd9e6b39750
diff --git a/av1/common/convolve.c b/av1/common/convolve.c
index 7de1074..4b8e39a 100644
--- a/av1/common/convolve.c
+++ b/av1/common/convolve.c
@@ -702,9 +702,7 @@
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
- dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
- DIST_PRECISION_BITS - 1);
+ dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
@@ -742,9 +740,7 @@
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
- dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
- DIST_PRECISION_BITS - 1);
+ dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
@@ -818,9 +814,7 @@
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
- dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
- DIST_PRECISION_BITS - 1);
+ dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
@@ -985,9 +979,7 @@
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
- dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
- DIST_PRECISION_BITS - 1);
+ dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
@@ -1060,9 +1052,7 @@
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
- dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
- DIST_PRECISION_BITS - 1);
+ dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
diff --git a/av1/common/warped_motion.c b/av1/common/warped_motion.c
index f36b8be..63fbc27 100644
--- a/av1/common/warped_motion.c
+++ b/av1/common/warped_motion.c
@@ -516,7 +516,7 @@
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
*p += sum * conv_params->bck_offset;
- *p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1);
+ *p >>= (DIST_PRECISION_BITS - 1);
} else {
*p = sum * conv_params->fwd_offset;
}
@@ -820,7 +820,7 @@
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
*p += sum * conv_params->bck_offset;
- *p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1);
+ *p >>= (DIST_PRECISION_BITS - 1);
} else {
*p = sum * conv_params->fwd_offset;
}
diff --git a/av1/common/x86/av1_convolve_scale_sse4.c b/av1/common/x86/av1_convolve_scale_sse4.c
index 3959b06..1c27200 100644
--- a/av1/common/x86/av1_convolve_scale_sse4.c
+++ b/av1/common/x86/av1_convolve_scale_sse4.c
@@ -263,7 +263,6 @@
#if CONFIG_JNT_COMP
const __m128i fwd_offset = _mm_set1_epi32(conv_params->fwd_offset);
const __m128i bck_offset = _mm_set1_epi32(conv_params->bck_offset);
- const __m128i jnt_round = _mm_set1_epi32(1 << (DIST_PRECISION_BITS - 2));
#endif // CONFIG_JNT_COMP
int y_qn = subpel_y_qn;
@@ -315,11 +314,10 @@
__m128i result;
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- result = _mm_srai_epi32(
- _mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
- _mm_mullo_epi32(subbed, bck_offset)),
- jnt_round),
- DIST_PRECISION_BITS - 1);
+ result =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
+ _mm_mullo_epi32(subbed, bck_offset)),
+ DIST_PRECISION_BITS - 1);
} else {
result = _mm_mullo_epi32(subbed, fwd_offset);
}
@@ -347,8 +345,7 @@
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
- dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
- DIST_PRECISION_BITS - 1);
+ dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
@@ -385,7 +382,6 @@
#if CONFIG_JNT_COMP
const __m128i fwd_offset = _mm_set1_epi32(conv_params->fwd_offset);
const __m128i bck_offset = _mm_set1_epi32(conv_params->bck_offset);
- const __m128i jnt_round = _mm_set1_epi32(1 << (DIST_PRECISION_BITS - 2));
#endif // CONFIG_JNT_COMP
int y_qn = subpel_y_qn;
@@ -434,11 +430,10 @@
__m128i result;
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- result = _mm_srai_epi32(
- _mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
- _mm_mullo_epi32(subbed, bck_offset)),
- jnt_round),
- DIST_PRECISION_BITS - 1);
+ result =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
+ _mm_mullo_epi32(subbed, bck_offset)),
+ DIST_PRECISION_BITS - 1);
} else {
result = _mm_mullo_epi32(subbed, fwd_offset);
}
@@ -466,8 +461,7 @@
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
- dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
- DIST_PRECISION_BITS - 1);
+ dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
diff --git a/av1/common/x86/convolve_2d_sse2.c b/av1/common/x86/convolve_2d_sse2.c
index c327735..111bc85 100644
--- a/av1/common/x86/convolve_2d_sse2.c
+++ b/av1/common/x86/convolve_2d_sse2.c
@@ -676,8 +676,6 @@
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
- const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
- const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
if (!(w % 16)) {
for (i = 0; i < h; ++i) {
@@ -697,26 +695,22 @@
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
- d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
- DIST_PRECISION_BITS - 1);
+ d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_1, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
- d32_1 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
- DIST_PRECISION_BITS - 1);
+ d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_2, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 2), weighted_res);
- d32_2 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
- DIST_PRECISION_BITS - 1);
+ d32_2 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_3, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 3), weighted_res);
- d32_3 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
- DIST_PRECISION_BITS - 1);
+ d32_3 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
@@ -763,14 +757,12 @@
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
- d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
- DIST_PRECISION_BITS - 1);
+ d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_1, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
- d32_1 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
- DIST_PRECISION_BITS - 1);
+ d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
@@ -806,8 +798,7 @@
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
- d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
- DIST_PRECISION_BITS - 1);
+ d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
}
@@ -838,8 +829,7 @@
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadl_epi64(p), weighted_res);
- d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
- DIST_PRECISION_BITS - 1);
+ d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
}
diff --git a/av1/common/x86/convolve_2d_sse4.c b/av1/common/x86/convolve_2d_sse4.c
index 8e5e286..123350a 100644
--- a/av1/common/x86/convolve_2d_sse4.c
+++ b/av1/common/x86/convolve_2d_sse4.c
@@ -47,9 +47,6 @@
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set_epi32(w0, w0, w0, w0);
const __m128i wt1 = _mm_set_epi32(w1, w1, w1, w1);
- const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
- const __m128i jnt_r = _mm_set_epi32(jnt_round_const, jnt_round_const,
- jnt_round_const, jnt_round_const);
/* Horizontal filter */
{
@@ -207,18 +204,14 @@
if (do_average) {
_mm_storeu_si128(
p + 0, _mm_srai_epi32(
- _mm_add_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
- _mm_mullo_epi32(
- res_lo_round, wt1)),
- jnt_r),
+ _mm_add_epi32(_mm_loadu_si128(p + 0),
+ _mm_mullo_epi32(res_lo_round, wt1)),
DIST_PRECISION_BITS - 1));
_mm_storeu_si128(
p + 1, _mm_srai_epi32(
- _mm_add_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1),
- _mm_mullo_epi32(
- res_hi_round, wt1)),
- jnt_r),
+ _mm_add_epi32(_mm_loadu_si128(p + 1),
+ _mm_mullo_epi32(res_hi_round, wt1)),
DIST_PRECISION_BITS - 1));
} else {
_mm_storeu_si128(p + 0, _mm_mullo_epi32(res_lo_round, wt0));
diff --git a/av1/common/x86/highbd_convolve_2d_sse4.c b/av1/common/x86/highbd_convolve_2d_sse4.c
index e23e0d0..853ecfd 100644
--- a/av1/common/x86/highbd_convolve_2d_sse4.c
+++ b/av1/common/x86/highbd_convolve_2d_sse4.c
@@ -39,8 +39,6 @@
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
- const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
- const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
// Check that, even with 12-bit input, the intermediate values will fit
// into an unsigned 15-bit intermediate array.
@@ -202,12 +200,10 @@
_mm_loadu_si128(p + 0), _mm_mullo_epi32(res_lo_round, wt1));
const __m128i jnt_sum_hi = _mm_add_epi32(
_mm_loadu_si128(p + 1), _mm_mullo_epi32(res_hi_round, wt1));
- const __m128i jnt_round_res_lo = _mm_add_epi32(jnt_sum_lo, jnt_r);
- const __m128i jnt_round_res_hi = _mm_add_epi32(jnt_sum_hi, jnt_r);
const __m128i final_lo =
- _mm_srai_epi32(jnt_round_res_lo, DIST_PRECISION_BITS - 1);
+ _mm_srai_epi32(jnt_sum_lo, DIST_PRECISION_BITS - 1);
const __m128i final_hi =
- _mm_srai_epi32(jnt_round_res_hi, DIST_PRECISION_BITS - 1);
+ _mm_srai_epi32(jnt_sum_hi, DIST_PRECISION_BITS - 1);
_mm_storeu_si128(p + 0, final_lo);
_mm_storeu_si128(p + 1, final_hi);
diff --git a/av1/common/x86/highbd_warp_plane_sse4.c b/av1/common/x86/highbd_warp_plane_sse4.c
index d40a969..0d4d17f 100644
--- a/av1/common/x86/highbd_warp_plane_sse4.c
+++ b/av1/common/x86/highbd_warp_plane_sse4.c
@@ -42,8 +42,6 @@
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
- const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
- const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
#endif // CONFIG_JNT_COMP
/* Note: For this code to work, the left/right frame borders need to be
@@ -320,8 +318,7 @@
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
_mm_mullo_epi32(res_lo, wt1));
- const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
- res_lo = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
+ res_lo = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
res_lo = _mm_mullo_epi32(res_lo, wt0);
}
@@ -345,8 +342,7 @@
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_mullo_epi32(res_hi, wt1));
- const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
- res_hi = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
+ res_hi = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
res_hi = _mm_mullo_epi32(res_hi, wt0);
}
diff --git a/av1/common/x86/warp_plane_sse4.c b/av1/common/x86/warp_plane_sse4.c
index e0d6206..94b5f09 100644
--- a/av1/common/x86/warp_plane_sse4.c
+++ b/av1/common/x86/warp_plane_sse4.c
@@ -39,8 +39,6 @@
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
- const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
- const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
#endif // CONFIG_JNT_COMP
/* Note: For this code to work, the left/right frame borders need to be
@@ -317,8 +315,7 @@
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
_mm_mullo_epi32(res_lo, wt1));
- const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
- res_lo = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
+ res_lo = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
res_lo = _mm_mullo_epi32(res_lo, wt0);
}
@@ -340,8 +337,7 @@
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_mullo_epi32(res_hi, wt1));
- const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
- res_hi = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
+ res_hi = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
res_hi = _mm_mullo_epi32(res_hi, wt0);
}