Round compound prediction to 16 bits
Apply the shifts for compound prediction immediately when the sum
or weighted sum (for jnt_comp) is computed, so that all
intermediate results fit into 16 bits.
Note: for now the buffer is still 32 bits. We still need new SIMD
functions for 16 bits and, finally, to reduce the buffer to 16 bits.
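
In other words, the second (averaging) pass now folds the final shift
into the per-pixel store instead of leaving full-precision sums in the
buffer. A rough sketch of the scalar update, simplified by merging the
jnt_comp and plain-average branches of the convolve.c changes below
(res is the already-rounded result of the current pass):

  if (conv_params->do_average) {
    int32_t tmp = dst[y * dst_stride + x];
    if (conv_params->use_jnt_comp_avg)
      /* weighted sum of the two passes, scaled back right away */
      tmp = (tmp * conv_params->fwd_offset + res * conv_params->bck_offset) >>
            DIST_PRECISION_BITS;
    else
      /* plain average of the two passes */
      tmp = (tmp + res) >> 1;
    dst[y * dst_stride + x] = tmp;
  } else {
    dst[y * dst_stride + x] = res;  /* first pass: store unweighted result */
  }

Because both branches shift right immediately, the stored value stays
within the 16-bit range targeted above even though the buffer itself is
still declared as 32-bit.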
Change-Id: Ia46a4736d69aa028623dfb9f036a6ce527e5cd9f
diff --git a/aom_dsp/x86/convolve_avx2.h b/aom_dsp/x86/convolve_avx2.h
index ec5868e..3ca424c 100644
--- a/aom_dsp/x86/convolve_avx2.h
+++ b/aom_dsp/x86/convolve_avx2.h
@@ -122,11 +122,12 @@
static INLINE void add_store_aligned(CONV_BUF_TYPE *const dst,
const __m256i *const res,
- const __m256i *const avg_mask) {
+ const __m256i *const avg_mask, int shift) {
__m256i d;
d = _mm256_load_si256((__m256i *)dst);
d = _mm256_and_si256(d, *avg_mask);
d = _mm256_add_epi32(d, *res);
+ if (shift) d = _mm256_srai_epi32(d, 1);
_mm256_store_si256((__m256i *)dst, d);
}
diff --git a/av1/common/convolve.c b/av1/common/convolve.c
index fef179c..ba9b96f 100644
--- a/av1/common/convolve.c
+++ b/av1/common/convolve.c
@@ -424,10 +424,13 @@
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) -
((1 << (offset_bits - conv_params->round_1)) +
(1 << (offset_bits - conv_params->round_1 - 1)));
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
}
}
@@ -460,10 +463,13 @@
}
res *= (1 << bits);
res = ROUND_POWER_OF_TWO(res, conv_params->round_1);
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
}
}
@@ -495,10 +501,13 @@
res += x_filter[k] * src[y * src_stride + x - fo_horiz + k];
}
res = (1 << bits) * ROUND_POWER_OF_TWO(res, conv_params->round_0);
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
}
}
@@ -524,10 +533,13 @@
for (int y = 0; y < h; ++y) {
for (int x = 0; x < w; ++x) {
CONV_BUF_TYPE res = src[y * src_stride + x] << bits;
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
}
}
@@ -714,15 +726,20 @@
(1 << (offset_bits - conv_params->round_1 - 1)));
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- dst[y * dst_stride + x] += res * conv_params->bck_offset;
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+ dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
} else {
- dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+ dst[y * dst_stride + x] = res;
}
} else {
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
}
}
@@ -756,15 +773,20 @@
res = ROUND_POWER_OF_TWO(res, conv_params->round_1);
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- dst[y * dst_stride + x] += res * conv_params->bck_offset;
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+ dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
} else {
- dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+ dst[y * dst_stride + x] = res;
}
} else {
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
}
}
@@ -797,15 +819,20 @@
res = (1 << bits) * ROUND_POWER_OF_TWO(res, conv_params->round_0);
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- dst[y * dst_stride + x] += res * conv_params->bck_offset;
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+ dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
} else {
- dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+ dst[y * dst_stride + x] = res;
}
} else {
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
}
}
@@ -834,15 +861,20 @@
CONV_BUF_TYPE res = src[y * src_stride + x] << bits;
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- dst[y * dst_stride + x] += res * conv_params->bck_offset;
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+ dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
} else {
- dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+ dst[y * dst_stride + x] = res;
}
} else {
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
}
}
@@ -907,21 +939,29 @@
#if CONFIG_JNT_COMP
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- dst[y * dst_stride + x] += res * conv_params->bck_offset;
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+ dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
} else {
- dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+ dst[y * dst_stride + x] = res;
}
} else {
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
#else
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
#endif // CONFIG_JNT_COMP
}
src_vert++;
@@ -1044,10 +1084,13 @@
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) -
((1 << (offset_bits - conv_params->round_1)) +
(1 << (offset_bits - conv_params->round_1 - 1)));
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
}
}
@@ -1339,15 +1382,20 @@
(1 << (offset_bits - conv_params->round_1 - 1)));
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- dst[y * dst_stride + x] += res * conv_params->bck_offset;
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+ dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
} else {
- dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+ dst[y * dst_stride + x] = res;
}
} else {
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
}
}
@@ -1533,21 +1581,29 @@
#if CONFIG_JNT_COMP
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- dst[y * dst_stride + x] += res * conv_params->bck_offset;
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+ dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
} else {
- dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+ dst[y * dst_stride + x] = res;
}
} else {
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
}
#else
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
#endif // CONFIG_JNT_COMP
}
src_vert++;
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index 18e316a..d975aa8 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -30,17 +30,9 @@
// prediction.
static INLINE int get_compound_post_rounding_bits(
- const MB_MODE_INFO *const mbmi, const ConvolveParams *conv_params) {
+ const ConvolveParams *conv_params) {
assert(conv_params->is_compound);
- int round_bits =
- 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
- if (is_masked_compound_type(mbmi->interinter_compound_type))
- return round_bits;
- round_bits += conv_params->is_compound;
-#if CONFIG_JNT_COMP
- if (conv_params->use_jnt_comp_avg) round_bits += DIST_PRECISION_BITS - 1;
-#endif // CONFIG_JNT_COMP
- return round_bits;
+ return 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
}
static INLINE int allow_warp(const MODE_INFO *const mi,
@@ -1105,8 +1097,7 @@
if (conv_params.is_compound) {
assert(conv_params.dst != NULL);
- int round_bits =
- get_compound_post_rounding_bits(&mi->mbmi, &conv_params);
+ int round_bits = get_compound_post_rounding_bits(&conv_params);
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
av1_highbd_convolve_rounding(tmp_dst, tmp_dst_stride, dst,
dst_buf->stride, b4_w, b4_h,
@@ -1242,7 +1233,7 @@
// TODO(angiebird): This part needs optimization
if (conv_params.is_compound) {
assert(conv_params.dst != NULL);
- int round_bits = get_compound_post_rounding_bits(&mi->mbmi, &conv_params);
+ int round_bits = get_compound_post_rounding_bits(&conv_params);
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
av1_highbd_convolve_rounding(tmp_dst, MAX_SB_SIZE, dst, dst_buf->stride,
w, h, round_bits, xd->bd);
diff --git a/av1/common/warped_motion.c b/av1/common/warped_motion.c
index fc4d8da..fa6f7ec 100644
--- a/av1/common/warped_motion.c
+++ b/av1/common/warped_motion.c
@@ -508,21 +508,30 @@
#if CONFIG_JNT_COMP
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- *p += sum * conv_params->bck_offset;
+ int32_t tmp32 = *p;
+ tmp32 = tmp32 * conv_params->fwd_offset +
+ sum * conv_params->bck_offset;
+ *p = tmp32 >> DIST_PRECISION_BITS;
} else {
- *p = sum * conv_params->fwd_offset;
+ *p = sum;
}
} else {
- if (conv_params->do_average)
- *p += sum;
- else
+ if (conv_params->do_average) {
+ int32_t tmp32 = *p;
+ tmp32 += sum;
+ *p = tmp32 >> 1;
+ } else {
*p = sum;
+ }
}
#else
- if (conv_params->do_average)
- *p += sum;
- else
+ if (conv_params->do_average) {
+ int32_t tmp32 = *p;
+ tmp32 += sum;
+ *p = tmp32 >> 1;
+ } else {
*p = sum;
+ }
#endif // CONFIG_JNT_COMP
} else {
uint16_t *p =
@@ -802,21 +811,30 @@
#if CONFIG_JNT_COMP
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- *p += sum * conv_params->bck_offset;
+ int32_t tmp32 = *p;
+ tmp32 = tmp32 * conv_params->fwd_offset +
+ sum * conv_params->bck_offset;
+ *p = tmp32 >> DIST_PRECISION_BITS;
} else {
- *p = sum * conv_params->fwd_offset;
+ *p = sum;
}
} else {
- if (conv_params->do_average)
- *p += sum;
- else
+ if (conv_params->do_average) {
+ int32_t tmp32 = *p;
+ tmp32 += sum;
+ *p = tmp32 >> 1;
+ } else {
*p = sum;
+ }
}
#else
- if (conv_params->do_average)
- *p += sum;
- else
+ if (conv_params->do_average) {
+ int32_t tmp32 = *p;
+ tmp32 += sum;
+ *p = tmp32 >> 1;
+ } else {
*p = sum;
+ }
#endif // CONFIG_JNT_COMP
} else {
uint8_t *p =
diff --git a/av1/common/x86/av1_convolve_scale_sse4.c b/av1/common/x86/av1_convolve_scale_sse4.c
index e53717d..973dbe3 100644
--- a/av1/common/x86/av1_convolve_scale_sse4.c
+++ b/av1/common/x86/av1_convolve_scale_sse4.c
@@ -314,20 +314,26 @@
__m128i result;
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- result = _mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
- _mm_mullo_epi32(subbed, bck_offset));
+ __m128i tmp = _mm_loadu_si128((__m128i *)dst_x);
+ tmp = _mm_add_epi32(_mm_mullo_epi32(tmp, fwd_offset),
+ _mm_mullo_epi32(subbed, bck_offset));
+ result = _mm_srai_epi32(tmp, DIST_PRECISION_BITS);
} else {
- result = _mm_mullo_epi32(subbed, fwd_offset);
+ result = subbed;
}
} else {
- result = (conv_params->do_average)
- ? _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x))
- : subbed;
+ result =
+ (conv_params->do_average)
+ ? _mm_srai_epi32(
+ _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x)),
+ 1)
+ : subbed;
}
#else
const __m128i result =
(conv_params->do_average)
- ? _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x))
+ ? _mm_srai_epi32(
+ _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x)), 1)
: subbed;
#endif // CONFIG_JNT_COMP
@@ -341,16 +347,21 @@
#if CONFIG_JNT_COMP
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- dst[y * dst_stride + x] += res * conv_params->bck_offset;
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+ dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
} else {
- dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+ dst[y * dst_stride + x] = res;
}
} else {
#endif // CONFIG_JNT_COMP
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
#if CONFIG_JNT_COMP
}
#endif // CONFIG_JNT_COMP
@@ -426,20 +437,26 @@
__m128i result;
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- result = _mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
- _mm_mullo_epi32(subbed, bck_offset));
+ __m128i tmp = _mm_loadu_si128((__m128i *)dst_x);
+ tmp = _mm_add_epi32(_mm_mullo_epi32(tmp, fwd_offset),
+ _mm_mullo_epi32(subbed, bck_offset));
+ result = _mm_srai_epi32(tmp, DIST_PRECISION_BITS);
} else {
- result = _mm_mullo_epi32(subbed, fwd_offset);
+ result = subbed;
}
} else {
- result = (conv_params->do_average)
- ? _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x))
- : subbed;
+ result =
+ (conv_params->do_average)
+ ? _mm_srai_epi32(
+ _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x)),
+ 1)
+ : subbed;
}
#else
const __m128i result =
(conv_params->do_average)
- ? _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x))
+ ? _mm_srai_epi32(
+ _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x)), 1)
: subbed;
#endif // CONFIG_JNT_COMP
@@ -453,16 +470,21 @@
#if CONFIG_JNT_COMP
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
- dst[y * dst_stride + x] += res * conv_params->bck_offset;
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+ dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
} else {
- dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+ dst[y * dst_stride + x] = res;
}
} else {
#endif // CONFIG_JNT_COMP
- if (conv_params->do_average)
- dst[y * dst_stride + x] += res;
- else
+ if (conv_params->do_average) {
+ int32_t tmp = dst[y * dst_stride + x];
+ tmp += res;
+ dst[y * dst_stride + x] = tmp >> 1;
+ } else {
dst[y * dst_stride + x] = res;
+ }
#if CONFIG_JNT_COMP
}
#endif // CONFIG_JNT_COMP
diff --git a/av1/common/x86/convolve_2d_avx2.c b/av1/common/x86/convolve_2d_avx2.c
index fafe344..6407c3a 100644
--- a/av1/common/x86/convolve_2d_avx2.c
+++ b/av1/common/x86/convolve_2d_avx2.c
@@ -126,9 +126,10 @@
const __m256i res_bx =
_mm256_permute2x128_si256(res_a_round, res_b_round, 0x31);
- add_store_aligned(&dst[i * dst_stride + j], &res_ax, &avg_mask);
+ add_store_aligned(&dst[i * dst_stride + j], &res_ax, &avg_mask,
+ conv_params->do_average);
add_store_aligned(&dst[i * dst_stride + j + dst_stride], &res_bx,
- &avg_mask);
+ &avg_mask, conv_params->do_average);
} else {
const __m128i res_ax = _mm256_extracti128_si256(res_a_round, 0);
const __m128i res_bx = _mm256_extracti128_si256(res_a_round, 1);
@@ -140,6 +141,10 @@
r1 = _mm_and_si128(r1, _mm256_extracti128_si256(avg_mask, 0));
r0 = _mm_add_epi32(r0, res_ax);
r1 = _mm_add_epi32(r1, res_bx);
+ if (conv_params->do_average) {
+ r0 = _mm_srai_epi32(r0, 1);
+ r1 = _mm_srai_epi32(r1, 1);
+ }
_mm_store_si128((__m128i *)&dst[i * dst_stride + j], r0);
_mm_store_si128((__m128i *)&dst[i * dst_stride + j + dst_stride], r1);
}
diff --git a/av1/common/x86/convolve_2d_sse2.c b/av1/common/x86/convolve_2d_sse2.c
index 96a6042..941d195 100644
--- a/av1/common/x86/convolve_2d_sse2.c
+++ b/av1/common/x86/convolve_2d_sse2.c
@@ -193,10 +193,14 @@
// Accumulate values into the destination buffer
__m128i *const p = (__m128i *)&dst[i * dst_stride + j];
if (do_average) {
- _mm_storeu_si128(p + 0,
- _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round));
- _mm_storeu_si128(p + 1,
- _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round));
+ _mm_storeu_si128(
+ p + 0,
+ _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round), 1));
+ _mm_storeu_si128(
+ p + 1,
+ _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round), 1));
} else {
_mm_storeu_si128(p + 0, res_lo_round);
_mm_storeu_si128(p + 1, res_hi_round);
@@ -444,10 +448,18 @@
__m128i *const p = (__m128i *)&dst[j];
if (do_average) {
- _mm_storeu_si128(p + 0, _mm_add_epi32(_mm_loadu_si128(p + 0), d32_0));
- _mm_storeu_si128(p + 1, _mm_add_epi32(_mm_loadu_si128(p + 1), d32_1));
- _mm_storeu_si128(p + 2, _mm_add_epi32(_mm_loadu_si128(p + 2), d32_2));
- _mm_storeu_si128(p + 3, _mm_add_epi32(_mm_loadu_si128(p + 3), d32_3));
+ _mm_storeu_si128(
+ p + 0,
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0), d32_0), 1));
+ _mm_storeu_si128(
+ p + 1,
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1), d32_1), 1));
+ _mm_storeu_si128(
+ p + 2,
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 2), d32_2), 1));
+ _mm_storeu_si128(
+ p + 3,
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 3), d32_3), 1));
} else {
_mm_storeu_si128(p + 0, d32_0);
_mm_storeu_si128(p + 1, d32_1);
@@ -471,8 +483,12 @@
__m128i *const p = (__m128i *)&dst[j];
if (do_average) {
- _mm_storeu_si128(p + 0, _mm_add_epi32(_mm_loadu_si128(p + 0), d32_0));
- _mm_storeu_si128(p + 1, _mm_add_epi32(_mm_loadu_si128(p + 1), d32_1));
+ _mm_storeu_si128(
+ p + 0,
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0), d32_0), 1));
+ _mm_storeu_si128(
+ p + 1,
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1), d32_1), 1));
} else {
_mm_storeu_si128(p + 0, d32_0);
_mm_storeu_si128(p + 1, d32_1);
@@ -491,7 +507,8 @@
d32_0 = _mm_sll_epi32(d32_0, left_shift);
__m128i *const p = (__m128i *)&dst[j];
if (do_average) {
- _mm_storeu_si128(p, _mm_add_epi32(_mm_loadu_si128(p), d32_0));
+ _mm_storeu_si128(
+ p, _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p), d32_0), 1));
} else {
_mm_storeu_si128(p, d32_0);
}
@@ -509,7 +526,8 @@
d32_0 = _mm_sll_epi32(d32_0, left_shift);
__m128i *const p = (__m128i *)&dst[j];
if (do_average) {
- _mm_storel_epi64(p, _mm_add_epi32(_mm_loadl_epi64(p), d32_0));
+ _mm_storel_epi64(
+ p, _mm_srai_epi32(_mm_add_epi32(_mm_loadl_epi64(p), d32_0), 1));
} else {
_mm_storel_epi64(p, d32_0);
}
@@ -707,39 +725,52 @@
if (do_average) {
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
- __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
- d32_0 = sum;
+ __m128i tmp = _mm_loadu_si128(p + 0);
+ __m128i sum =
+ _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+ d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
mul = _mm_mullo_epi16(d32_1, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
- sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
- d32_1 = sum;
+ tmp = _mm_loadu_si128(p + 1);
+ sum = _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+ d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
mul = _mm_mullo_epi16(d32_2, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
- sum = _mm_add_epi32(_mm_loadu_si128(p + 2), weighted_res);
- d32_2 = sum;
+ tmp = _mm_loadu_si128(p + 2);
+ sum = _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+ d32_2 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
mul = _mm_mullo_epi16(d32_3, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
- sum = _mm_add_epi32(_mm_loadu_si128(p + 3), weighted_res);
- d32_3 = sum;
+ tmp = _mm_loadu_si128(p + 3);
+ sum = _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+ d32_3 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
} else {
- d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
- d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
- d32_2 = _mm_sll_epi32(_mm_mullo_epi16(d32_2, wt0), left_shift);
- d32_3 = _mm_sll_epi32(_mm_mullo_epi16(d32_3, wt0), left_shift);
+ d32_0 = _mm_sll_epi32(d32_0, left_shift);
+ d32_1 = _mm_sll_epi32(d32_1, left_shift);
+ d32_2 = _mm_sll_epi32(d32_2, left_shift);
+ d32_3 = _mm_sll_epi32(d32_3, left_shift);
}
} else {
if (do_average) {
- d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0),
- _mm_sll_epi32(d32_0, left_shift));
- d32_1 = _mm_add_epi32(_mm_loadu_si128(p + 1),
- _mm_sll_epi32(d32_1, left_shift));
- d32_2 = _mm_add_epi32(_mm_loadu_si128(p + 2),
- _mm_sll_epi32(d32_2, left_shift));
- d32_3 = _mm_add_epi32(_mm_loadu_si128(p + 3),
- _mm_sll_epi32(d32_3, left_shift));
+ d32_0 =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
+ _mm_sll_epi32(d32_0, left_shift)),
+ 1);
+ d32_1 =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1),
+ _mm_sll_epi32(d32_1, left_shift)),
+ 1);
+ d32_2 =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 2),
+ _mm_sll_epi32(d32_2, left_shift)),
+ 1);
+ d32_3 =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 3),
+ _mm_sll_epi32(d32_3, left_shift)),
+ 1);
} else {
d32_0 = _mm_sll_epi32(d32_0, left_shift);
d32_1 = _mm_sll_epi32(d32_1, left_shift);
@@ -769,23 +800,30 @@
if (do_average) {
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
- __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
- d32_0 = sum;
+ __m128i tmp = _mm_loadu_si128(p + 0);
+ __m128i sum =
+ _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+ d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
mul = _mm_mullo_epi16(d32_1, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
- sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
- d32_1 = sum;
+ tmp = _mm_loadu_si128(p + 1);
+ sum = _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+ d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
} else {
- d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
- d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
+ d32_0 = _mm_sll_epi32(d32_0, left_shift);
+ d32_1 = _mm_sll_epi32(d32_1, left_shift);
}
} else {
if (do_average) {
- d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0),
- _mm_sll_epi32(d32_0, left_shift));
- d32_1 = _mm_add_epi32(_mm_loadu_si128(p + 1),
- _mm_sll_epi32(d32_1, left_shift));
+ d32_0 =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
+ _mm_sll_epi32(d32_0, left_shift)),
+ 1);
+ d32_1 =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1),
+ _mm_sll_epi32(d32_1, left_shift)),
+ 1);
} else {
d32_0 = _mm_sll_epi32(d32_0, left_shift);
d32_1 = _mm_sll_epi32(d32_1, left_shift);
@@ -810,15 +848,19 @@
if (do_average) {
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
- __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
- d32_0 = sum;
+ __m128i tmp = _mm_loadu_si128(p + 0);
+ __m128i sum =
+ _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+ d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
} else {
- d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
+ d32_0 = _mm_sll_epi32(d32_0, left_shift);
}
} else {
if (do_average) {
- d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0),
- _mm_sll_epi32(d32_0, left_shift));
+ d32_0 =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
+ _mm_sll_epi32(d32_0, left_shift)),
+ 1);
} else {
d32_0 = _mm_sll_epi32(d32_0, left_shift);
}
@@ -841,15 +883,19 @@
if (do_average) {
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
- __m128i sum = _mm_add_epi32(_mm_loadl_epi64(p), weighted_res);
- d32_0 = sum;
+ __m128i tmp = _mm_loadl_epi64(p);
+ __m128i sum =
+ _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+ d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
} else {
- d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
+ d32_0 = _mm_sll_epi32(d32_0, left_shift);
}
} else {
if (do_average) {
- d32_0 = _mm_add_epi32(_mm_loadl_epi64(p),
- _mm_sll_epi32(d32_0, left_shift));
+ d32_0 =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadl_epi64(p),
+ _mm_sll_epi32(d32_0, left_shift)),
+ 1);
} else {
d32_0 = _mm_sll_epi32(d32_0, left_shift);
}
diff --git a/av1/common/x86/convolve_avx2.c b/av1/common/x86/convolve_avx2.c
index 63f6138..ffc870b 100644
--- a/av1/common/x86/convolve_avx2.c
+++ b/av1/common/x86/convolve_avx2.c
@@ -452,7 +452,8 @@
_mm256_add_epi32(res_lo_0_shift, round_const), round_shift);
// Accumulate values into the destination buffer
- add_store_aligned(&dst[i * dst_stride + j], &res_lo_0_round, &avg_mask);
+ add_store_aligned(&dst[i * dst_stride + j], &res_lo_0_round, &avg_mask,
+ conv_params->do_average);
const __m256i res_lo_1_32b =
_mm256_cvtepi16_epi32(_mm256_extracti128_si256(res_lo, 1));
@@ -462,7 +463,7 @@
_mm256_add_epi32(res_lo_1_shift, round_const), round_shift);
add_store_aligned(&dst[i * dst_stride + j + dst_stride], &res_lo_1_round,
- &avg_mask);
+ &avg_mask, conv_params->do_average);
if (w - j > 8) {
const __m256i res_hi = convolve_lowbd(s + 4, coeffs);
@@ -475,7 +476,7 @@
_mm256_add_epi32(res_hi_0_shift, round_const), round_shift);
add_store_aligned(&dst[i * dst_stride + j + 8], &res_hi_0_round,
- &avg_mask);
+ &avg_mask, conv_params->do_average);
const __m256i res_hi_1_32b =
_mm256_cvtepi16_epi32(_mm256_extracti128_si256(res_hi, 1));
@@ -485,7 +486,7 @@
_mm256_add_epi32(res_hi_1_shift, round_const), round_shift);
add_store_aligned(&dst[i * dst_stride + j + 8 + dst_stride],
- &res_hi_1_round, &avg_mask);
+ &res_hi_1_round, &avg_mask, conv_params->do_average);
}
s[0] = s[1];
s[1] = s[2];
@@ -711,10 +712,11 @@
const __m256i res_hi_shift = _mm256_slli_epi32(res_hi_round, bits);
// Accumulate values into the destination buffer
- add_store_aligned(&dst[i * dst_stride + j], &res_lo_shift, &avg_mask);
+ add_store_aligned(&dst[i * dst_stride + j], &res_lo_shift, &avg_mask,
+ conv_params->do_average);
if (w - j > 8) {
add_store_aligned(&dst[i * dst_stride + j + 8], &res_hi_shift,
- &avg_mask);
+ &avg_mask, conv_params->do_average);
}
}
}
diff --git a/av1/common/x86/convolve_sse2.c b/av1/common/x86/convolve_sse2.c
index ab35226..a03f0ef 100644
--- a/av1/common/x86/convolve_sse2.c
+++ b/av1/common/x86/convolve_sse2.c
@@ -75,11 +75,12 @@
}
static INLINE void add_store(CONV_BUF_TYPE *const dst, const __m128i *const res,
- const __m128i *const avg_mask) {
+ const __m128i *const avg_mask, int shift) {
__m128i d;
d = _mm_load_si128((__m128i *)dst);
d = _mm_and_si128(d, *avg_mask);
d = _mm_add_epi32(d, *res);
+ if (shift) d = _mm_srai_epi32(d, 1);
_mm_store_si128((__m128i *)dst, d);
}
@@ -141,7 +142,7 @@
res_shift = _mm_sll_epi32(res, left_shift);
res_shift =
_mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
- add_store(dst, &res_shift, &avg_mask);
+ add_store(dst, &res_shift, &avg_mask, conv_params->do_average);
src_ptr += src_stride;
dst += dst_stride;
@@ -149,7 +150,7 @@
res_shift = _mm_sll_epi32(res, left_shift);
res_shift =
_mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
- add_store(dst, &res_shift, &avg_mask);
+ add_store(dst, &res_shift, &avg_mask, conv_params->do_average);
src_ptr += src_stride;
dst += dst_stride;
@@ -204,8 +205,10 @@
round_shift);
res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
round_shift);
- add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
- add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+ add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+ conv_params->do_average);
+ add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+ conv_params->do_average);
i++;
res_lo = convolve_lo_y(s + 1, coeffs); // Filter low index pixels
@@ -216,8 +219,10 @@
round_shift);
res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
round_shift);
- add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
- add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+ add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+ conv_params->do_average);
+ add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+ conv_params->do_average);
i++;
s[0] = s[2];
@@ -276,7 +281,7 @@
const __m128i res_lo_shift = _mm_sll_epi32(res_lo_round, left_shift);
// Accumulate values into the destination buffer
- add_store(dst, &res_lo_shift, &avg_mask);
+ add_store(dst, &res_lo_shift, &avg_mask, conv_params->do_average);
src_ptr += src_stride;
dst += dst_stride;
} while (--h);
@@ -315,8 +320,10 @@
const __m128i res_hi_shift = _mm_sll_epi32(res_hi_round, left_shift);
// Accumulate values into the destination buffer
- add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
- add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+ add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+ conv_params->do_average);
+ add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+ conv_params->do_average);
j += 8;
} while (j < w);
} while (++i < h);
diff --git a/av1/common/x86/highbd_convolve_2d_avx2.c b/av1/common/x86/highbd_convolve_2d_avx2.c
index cb13e9b..73c2bea 100644
--- a/av1/common/x86/highbd_convolve_2d_avx2.c
+++ b/av1/common/x86/highbd_convolve_2d_avx2.c
@@ -374,18 +374,26 @@
__m128i *const p = (__m128i *)&dst[i * dst_stride + j];
if (do_average) {
_mm_storeu_si128(
- p + 0, _mm_add_epi32(_mm_loadu_si128(p + 0),
- _mm256_extractf128_si256(res_lo_round, 0)));
+ p + 0, _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
+ _mm256_extractf128_si256(
+ res_lo_round, 0)),
+ 1));
_mm_storeu_si128(
- p + 1, _mm_add_epi32(_mm_loadu_si128(p + 1),
- _mm256_extractf128_si256(res_hi_round, 0)));
+ p + 1, _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1),
+ _mm256_extractf128_si256(
+ res_hi_round, 0)),
+ 1));
if (w - j > 8) {
- _mm_storeu_si128(p + 2, _mm_add_epi32(_mm_loadu_si128(p + 2),
- _mm256_extractf128_si256(
- res_lo_round, 1)));
- _mm_storeu_si128(p + 3, _mm_add_epi32(_mm_loadu_si128(p + 3),
- _mm256_extractf128_si256(
- res_hi_round, 1)));
+ _mm_storeu_si128(
+ p + 2, _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 2),
+ _mm256_extractf128_si256(
+ res_lo_round, 1)),
+ 1));
+ _mm_storeu_si128(
+ p + 3, _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 3),
+ _mm256_extractf128_si256(
+ res_hi_round, 1)),
+ 1));
}
} else {
_mm_storeu_si128(p + 0, _mm256_extractf128_si256(res_lo_round, 0));
diff --git a/av1/common/x86/highbd_convolve_2d_sse4.c b/av1/common/x86/highbd_convolve_2d_sse4.c
index 979d1dd..6980960 100644
--- a/av1/common/x86/highbd_convolve_2d_sse4.c
+++ b/av1/common/x86/highbd_convolve_2d_sse4.c
@@ -201,23 +201,35 @@
__m128i *const p = (__m128i *)&dst[i * dst_stride + j];
if (conv_params->use_jnt_comp_avg) {
if (do_average) {
- const __m128i jnt_sum_lo = _mm_add_epi32(
- _mm_loadu_si128(p + 0), _mm_mullo_epi32(res_lo_round, wt1));
- const __m128i jnt_sum_hi = _mm_add_epi32(
- _mm_loadu_si128(p + 1), _mm_mullo_epi32(res_hi_round, wt1));
+ const __m128i tmp_lo = _mm_loadu_si128(p + 0);
+ const __m128i tmp_hi = _mm_loadu_si128(p + 1);
+ const __m128i jnt_sum_lo =
+ _mm_add_epi32(_mm_mullo_epi32(tmp_lo, wt0),
+ _mm_mullo_epi32(res_lo_round, wt1));
+ const __m128i jnt_sum_hi =
+ _mm_add_epi32(_mm_mullo_epi32(tmp_hi, wt0),
+ _mm_mullo_epi32(res_hi_round, wt1));
+ const __m128i final_lo =
+ _mm_srai_epi32(jnt_sum_lo, DIST_PRECISION_BITS);
+ const __m128i final_hi =
+ _mm_srai_epi32(jnt_sum_hi, DIST_PRECISION_BITS);
- _mm_storeu_si128(p + 0, jnt_sum_lo);
- _mm_storeu_si128(p + 1, jnt_sum_hi);
+ _mm_storeu_si128(p + 0, final_lo);
+ _mm_storeu_si128(p + 1, final_hi);
} else {
- _mm_storeu_si128(p + 0, _mm_mullo_epi32(res_lo_round, wt0));
- _mm_storeu_si128(p + 1, _mm_mullo_epi32(res_hi_round, wt0));
+ _mm_storeu_si128(p + 0, res_lo_round);
+ _mm_storeu_si128(p + 1, res_hi_round);
}
} else {
if (do_average) {
_mm_storeu_si128(
- p + 0, _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round));
+ p + 0,
+ _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round), 1));
_mm_storeu_si128(
- p + 1, _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round));
+ p + 1,
+ _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round), 1));
} else {
_mm_storeu_si128(p + 0, res_lo_round);
_mm_storeu_si128(p + 1, res_hi_round);
diff --git a/av1/common/x86/highbd_convolve_2d_ssse3.c b/av1/common/x86/highbd_convolve_2d_ssse3.c
index ee948f8..ce348ac 100644
--- a/av1/common/x86/highbd_convolve_2d_ssse3.c
+++ b/av1/common/x86/highbd_convolve_2d_ssse3.c
@@ -192,10 +192,14 @@
// Accumulate values into the destination buffer
__m128i *const p = (__m128i *)&dst[i * dst_stride + j];
if (do_average) {
- _mm_storeu_si128(p + 0,
- _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round));
- _mm_storeu_si128(p + 1,
- _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round));
+ _mm_storeu_si128(
+ p + 0,
+ _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round), 1));
+ _mm_storeu_si128(
+ p + 1,
+ _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round), 1));
} else {
_mm_storeu_si128(p + 0, res_lo_round);
_mm_storeu_si128(p + 1, res_hi_round);
diff --git a/av1/common/x86/highbd_warp_plane_sse4.c b/av1/common/x86/highbd_warp_plane_sse4.c
index 5647eb3..4ebd8a6 100644
--- a/av1/common/x86/highbd_warp_plane_sse4.c
+++ b/av1/common/x86/highbd_warp_plane_sse4.c
@@ -309,19 +309,22 @@
#if CONFIG_JNT_COMP
if (conv_params->use_jnt_comp_avg) {
if (comp_avg) {
- const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
- _mm_mullo_epi32(res_lo, wt1));
- res_lo = sum;
- } else {
- res_lo = _mm_mullo_epi32(res_lo, wt0);
+ const __m128i sum =
+ _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p), wt0),
+ _mm_mullo_epi32(res_lo, wt1));
+ res_lo = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
}
} else {
- if (comp_avg) res_lo = _mm_add_epi32(_mm_loadu_si128(p), res_lo);
+ if (comp_avg)
+ res_lo =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p), res_lo), 1);
}
_mm_storeu_si128(p, res_lo);
#else
- if (comp_avg) res_lo = _mm_add_epi32(_mm_loadu_si128(p), res_lo);
+ if (comp_avg)
+ res_lo =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p), res_lo), 1);
_mm_storeu_si128(p, res_lo);
#endif
@@ -332,21 +335,22 @@
#if CONFIG_JNT_COMP
if (conv_params->use_jnt_comp_avg) {
if (comp_avg) {
- const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
- _mm_mullo_epi32(res_hi, wt1));
- res_hi = sum;
- } else {
- res_hi = _mm_mullo_epi32(res_hi, wt0);
+ const __m128i sum =
+ _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p + 1), wt0),
+ _mm_mullo_epi32(res_hi, wt1));
+ res_hi = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
}
} else {
if (comp_avg)
- res_hi = _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi);
+ res_hi = _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi), 1);
}
_mm_storeu_si128(p + 1, res_hi);
#else
if (comp_avg)
- res_hi = _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi);
+ res_hi = _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi), 1);
_mm_storeu_si128(p + 1, res_hi);
#endif
}
diff --git a/av1/common/x86/jnt_convolve_sse4.c b/av1/common/x86/jnt_convolve_sse4.c
index bc23365..54bef5a 100644
--- a/av1/common/x86/jnt_convolve_sse4.c
+++ b/av1/common/x86/jnt_convolve_sse4.c
@@ -76,23 +76,28 @@
}
static INLINE void add_store(CONV_BUF_TYPE *const dst, const __m128i *const res,
- const __m128i *const avg_mask) {
+ const __m128i *const avg_mask, int shift) {
__m128i d;
d = _mm_load_si128((__m128i *)dst);
d = _mm_and_si128(d, *avg_mask);
d = _mm_add_epi32(d, *res);
+ if (shift) d = _mm_srai_epi32(d, 1);
_mm_store_si128((__m128i *)dst, d);
}
#if CONFIG_JNT_COMP
static INLINE void mult_add_store(CONV_BUF_TYPE *const dst,
const __m128i *const res,
- const __m128i *const avg_mask,
- const __m128i *const wt) {
+ const __m128i *const wt0,
+ const __m128i *const wt1, int do_average) {
__m128i d;
- d = _mm_load_si128((__m128i *)dst);
- d = _mm_and_si128(d, *avg_mask);
- d = _mm_add_epi32(d, _mm_mullo_epi32(*res, *wt));
+ if (do_average) {
+ d = _mm_load_si128((__m128i *)dst);
+ d = _mm_add_epi32(_mm_mullo_epi32(d, *wt0), _mm_mullo_epi32(*res, *wt1));
+ d = _mm_srai_epi32(d, DIST_PRECISION_BITS);
+ } else {
+ d = *res;
+ }
_mm_store_si128((__m128i *)dst, d);
}
@@ -111,7 +116,6 @@
const __m128i avg_mask = _mm_set1_epi32(conv_params->do_average ? -1 : 0);
const __m128i wt0 = _mm_set1_epi32(conv_params->fwd_offset);
const __m128i wt1 = _mm_set1_epi32(conv_params->bck_offset);
- const __m128i wt = conv_params->do_average ? wt1 : wt0;
const __m128i round_const = _mm_set1_epi32((1 << conv_params->round_1) >> 1);
const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
__m128i coeffs[4];
@@ -156,9 +160,9 @@
res_shift =
_mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
if (conv_params->use_jnt_comp_avg)
- mult_add_store(dst, &res_shift, &avg_mask, &wt);
+ mult_add_store(dst, &res_shift, &wt0, &wt1, conv_params->do_average);
else
- add_store(dst, &res_shift, &avg_mask);
+ add_store(dst, &res_shift, &avg_mask, conv_params->do_average);
src_ptr += src_stride;
dst += dst_stride;
@@ -167,9 +171,9 @@
res_shift =
_mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
if (conv_params->use_jnt_comp_avg)
- mult_add_store(dst, &res_shift, &avg_mask, &wt);
+ mult_add_store(dst, &res_shift, &wt0, &wt1, conv_params->do_average);
else
- add_store(dst, &res_shift, &avg_mask);
+ add_store(dst, &res_shift, &avg_mask, conv_params->do_average);
src_ptr += src_stride;
dst += dst_stride;
@@ -225,13 +229,15 @@
res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
round_shift);
if (conv_params->use_jnt_comp_avg) {
- mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
- &wt);
- mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
- &wt);
+ mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &wt0,
+ &wt1, conv_params->do_average);
+ mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &wt0,
+ &wt1, conv_params->do_average);
} else {
- add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
- add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+ add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+ conv_params->do_average);
+ add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+ conv_params->do_average);
}
i++;
@@ -244,13 +250,15 @@
res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
round_shift);
if (conv_params->use_jnt_comp_avg) {
- mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
- &wt);
- mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
- &wt);
+ mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &wt0,
+ &wt1, conv_params->do_average);
+ mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &wt0,
+ &wt1, conv_params->do_average);
} else {
- add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
- add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+ add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+ conv_params->do_average);
+ add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+ conv_params->do_average);
}
i++;
@@ -285,7 +293,6 @@
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
- const __m128i wt = conv_params->do_average ? wt1 : wt0;
__m128i coeffs[4];
(void)filter_params_y;
@@ -314,9 +321,9 @@
// Accumulate values into the destination buffer
if (conv_params->use_jnt_comp_avg)
- mult_add_store(dst, &res_lo_shift, &avg_mask, &wt);
+ mult_add_store(dst, &res_lo_shift, &wt0, &wt1, conv_params->do_average);
else
- add_store(dst, &res_lo_shift, &avg_mask);
+ add_store(dst, &res_lo_shift, &avg_mask, conv_params->do_average);
src_ptr += src_stride;
dst += dst_stride;
} while (--h);
@@ -356,13 +363,15 @@
// Accumulate values into the destination buffer
if (conv_params->use_jnt_comp_avg) {
- mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
- &wt);
- mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
- &wt);
+ mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &wt0,
+ &wt1, conv_params->do_average);
+ mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &wt0,
+ &wt1, conv_params->do_average);
} else {
- add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
- add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+ add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+ conv_params->do_average);
+ add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+ conv_params->do_average);
}
j += 8;
} while (j < w);
@@ -553,24 +562,34 @@
// original c function at: av1/common/convolve.c: av1_convolve_2d_c
__m128i *const p = (__m128i *)&dst[i * dst_stride + j];
if (do_average) {
- _mm_storeu_si128(p + 0,
- _mm_add_epi32(_mm_loadu_si128(p + 0),
- _mm_mullo_epi32(res_lo_round, wt1)));
- _mm_storeu_si128(p + 1,
- _mm_add_epi32(_mm_loadu_si128(p + 1),
- _mm_mullo_epi32(res_hi_round, wt1)));
+ _mm_storeu_si128(
+ p + 0,
+ _mm_srai_epi32(
+ _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p + 0), wt0),
+ _mm_mullo_epi32(res_lo_round, wt1)),
+ DIST_PRECISION_BITS));
+ _mm_storeu_si128(
+ p + 1,
+ _mm_srai_epi32(
+ _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p + 1), wt0),
+ _mm_mullo_epi32(res_hi_round, wt1)),
+ DIST_PRECISION_BITS));
} else {
- _mm_storeu_si128(p + 0, _mm_mullo_epi32(res_lo_round, wt0));
- _mm_storeu_si128(p + 1, _mm_mullo_epi32(res_hi_round, wt0));
+ _mm_storeu_si128(p + 0, res_lo_round);
+ _mm_storeu_si128(p + 1, res_hi_round);
}
} else {
// Accumulate values into the destination buffer
__m128i *const p = (__m128i *)&dst[i * dst_stride + j];
if (do_average) {
_mm_storeu_si128(
- p + 0, _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round));
+ p + 0,
+ _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round), 1));
_mm_storeu_si128(
- p + 1, _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round));
+ p + 1,
+ _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round), 1));
} else {
_mm_storeu_si128(p + 0, res_lo_round);
_mm_storeu_si128(p + 1, res_hi_round);
diff --git a/av1/common/x86/warp_plane_sse4.c b/av1/common/x86/warp_plane_sse4.c
index 1e8ad47..b05b3b8 100644
--- a/av1/common/x86/warp_plane_sse4.c
+++ b/av1/common/x86/warp_plane_sse4.c
@@ -484,18 +484,21 @@
#if CONFIG_JNT_COMP
if (conv_params->use_jnt_comp_avg) {
if (comp_avg) {
- res_lo = _mm_add_epi32(_mm_loadu_si128(p),
+ res_lo = _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p), wt0),
_mm_mullo_epi32(res_lo, wt1));
- } else {
- res_lo = _mm_mullo_epi32(res_lo, wt0);
+ res_lo = _mm_srai_epi32(res_lo, DIST_PRECISION_BITS);
}
} else {
- if (comp_avg) res_lo = _mm_add_epi32(_mm_loadu_si128(p), res_lo);
+ if (comp_avg)
+ res_lo =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p), res_lo), 1);
}
_mm_storeu_si128(p, res_lo);
#else
- if (comp_avg) res_lo = _mm_add_epi32(_mm_loadu_si128(p), res_lo);
+ if (comp_avg)
+ res_lo =
+ _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p), res_lo), 1);
_mm_storeu_si128(p, res_lo);
#endif
if (p_width > 4) {
@@ -504,20 +507,22 @@
#if CONFIG_JNT_COMP
if (conv_params->use_jnt_comp_avg) {
if (comp_avg) {
- res_hi = _mm_add_epi32(_mm_loadu_si128(p + 1),
- _mm_mullo_epi32(res_hi, wt1));
- } else {
- res_hi = _mm_mullo_epi32(res_hi, wt0);
+ res_hi =
+ _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p + 1), wt0),
+ _mm_mullo_epi32(res_hi, wt1));
+ res_hi = _mm_srai_epi32(res_hi, DIST_PRECISION_BITS);
}
} else {
if (comp_avg)
- res_hi = _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi);
+ res_hi = _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi), 1);
}
_mm_storeu_si128(p + 1, res_hi);
#else
if (comp_avg)
- res_hi = _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi);
+ res_hi = _mm_srai_epi32(
+ _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi), 1);
_mm_storeu_si128(p + 1, res_hi);
#endif
}