Propagate constants in Neon dist_wtd_convolve_2d_vert functions
Rounding modes and other convolution parameters are known ahead of
time. This patch propagates the values into the Neon code paths for
dist_wtd_convolve_2d_vert_[6/8]tap_neon - enabling us to make some
useful simplifications and optimizations.
Co-authored by: Jonathan Wright <jonathan.wright@arm.com>
Change-Id: Ie3753db63a1362a51f67cbdc73d2cacd8516a948
diff --git a/av1/common/arm/convolve_neon.h b/av1/common/arm/convolve_neon.h
index 59c77b0..ee12d13 100644
--- a/av1/common/arm/convolve_neon.h
+++ b/av1/common/arm/convolve_neon.h
@@ -483,7 +483,6 @@
const int16x4_t s2, const int16x4_t s3,
const int16x4_t s4, const int16x4_t s5,
const int16x8_t y_filter,
- const int32x4_t round_shift_vec,
const int32x4_t offset_const) {
const int16x4_t y_filter_lo = vget_low_s16(y_filter);
const int16x4_t y_filter_hi = vget_high_s16(y_filter);
@@ -496,15 +495,13 @@
sum = vmlal_lane_s16(sum, s4, y_filter_hi, 1);
sum = vmlal_lane_s16(sum, s5, y_filter_hi, 2);
- sum = vqrshlq_s32(sum, round_shift_vec);
- return vqmovun_s32(sum);
+ return vqrshrun_n_s32(sum, COMPOUND_ROUND1_BITS);
}
static INLINE uint16x8_t convolve6_8_s32(const int16x8_t s0, const int16x8_t s1,
const int16x8_t s2, const int16x8_t s3,
const int16x8_t s4, const int16x8_t s5,
const int16x8_t y_filter,
- const int32x4_t round_shift_vec,
const int32x4_t offset_const) {
const int16x4_t y_filter_lo = vget_low_s16(y_filter);
const int16x4_t y_filter_hi = vget_high_s16(y_filter);
@@ -525,9 +522,8 @@
sum1 = vmlal_lane_s16(sum1, vget_high_s16(s4), y_filter_hi, 1);
sum1 = vmlal_lane_s16(sum1, vget_high_s16(s5), y_filter_hi, 2);
- sum0 = vqrshlq_s32(sum0, round_shift_vec);
- sum1 = vqrshlq_s32(sum1, round_shift_vec);
- return vcombine_u16(vqmovun_s32(sum0), vqmovun_s32(sum1));
+ return vcombine_u16(vqrshrun_n_s32(sum0, COMPOUND_ROUND1_BITS),
+ vqrshrun_n_s32(sum1, COMPOUND_ROUND1_BITS));
}
static INLINE uint16x4_t convolve8_4_s32(const int16x4_t s0, const int16x4_t s1,
@@ -535,7 +531,6 @@
const int16x4_t s4, const int16x4_t s5,
const int16x4_t s6, const int16x4_t s7,
const int16x8_t y_filter,
- const int32x4_t round_shift_vec,
const int32x4_t offset_const) {
const int16x4_t y_filter_lo = vget_low_s16(y_filter);
const int16x4_t y_filter_hi = vget_high_s16(y_filter);
@@ -550,8 +545,7 @@
sum = vmlal_lane_s16(sum, s6, y_filter_hi, 2);
sum = vmlal_lane_s16(sum, s7, y_filter_hi, 3);
- sum = vqrshlq_s32(sum, round_shift_vec);
- return vqmovun_s32(sum);
+ return vqrshrun_n_s32(sum, COMPOUND_ROUND1_BITS);
}
static INLINE uint16x8_t convolve8_8_s32(const int16x8_t s0, const int16x8_t s1,
@@ -559,7 +553,6 @@
const int16x8_t s4, const int16x8_t s5,
const int16x8_t s6, const int16x8_t s7,
const int16x8_t y_filter,
- const int32x4_t round_shift_vec,
const int32x4_t offset_const) {
const int16x4_t y_filter_lo = vget_low_s16(y_filter);
const int16x4_t y_filter_hi = vget_high_s16(y_filter);
@@ -584,9 +577,8 @@
sum1 = vmlal_lane_s16(sum1, vget_high_s16(s6), y_filter_hi, 2);
sum1 = vmlal_lane_s16(sum1, vget_high_s16(s7), y_filter_hi, 3);
- sum0 = vqrshlq_s32(sum0, round_shift_vec);
- sum1 = vqrshlq_s32(sum1, round_shift_vec);
- return vcombine_u16(vqmovun_s32(sum0), vqmovun_s32(sum1));
+ return vcombine_u16(vqrshrun_n_s32(sum0, COMPOUND_ROUND1_BITS),
+ vqrshrun_n_s32(sum1, COMPOUND_ROUND1_BITS));
}
#if !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index bd868ea..67d5eee 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -709,14 +709,13 @@
const int dst_stride = conv_params->dst_stride;
const int bd = 8;
- const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
- const int16_t sub_const = (1 << (offset_bits - conv_params->round_1)) +
- (1 << (offset_bits - conv_params->round_1 - 1));
+ const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+ const int16_t sub_const = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+ (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
const int16_t round_bits =
- 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
- const int offset = bd + 2 * FILTER_BITS - conv_params->round_0;
- const int32x4_t round_shift_vec = vdupq_n_s32(-(conv_params->round_1));
+ 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS;
+ const int offset = bd + 2 * FILTER_BITS - ROUND0_BITS;
const int32x4_t offset_const = vdupq_n_s32(1 << offset);
const int16x4_t sub_const_vec = vdup_n_s16(sub_const);
const uint16_t fwd_offset = conv_params->fwd_offset;
@@ -742,14 +741,10 @@
#if defined(__aarch64__)
load_s16_4x4(src_ptr, src_stride, &s5, &s6, &s7, &s8);
- d0 = convolve6_4_s32(s0, s1, s2, s3, s4, s5, y_filter, round_shift_vec,
- offset_const);
- d1 = convolve6_4_s32(s1, s2, s3, s4, s5, s6, y_filter, round_shift_vec,
- offset_const);
- d2 = convolve6_4_s32(s2, s3, s4, s5, s6, s7, y_filter, round_shift_vec,
- offset_const);
- d3 = convolve6_4_s32(s3, s4, s5, s6, s7, s8, y_filter, round_shift_vec,
- offset_const);
+ d0 = convolve6_4_s32(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
+ d1 = convolve6_4_s32(s1, s2, s3, s4, s5, s6, y_filter, offset_const);
+ d2 = convolve6_4_s32(s2, s3, s4, s5, s6, s7, y_filter, offset_const);
+ d3 = convolve6_4_s32(s3, s4, s5, s6, s7, s8, y_filter, offset_const);
if (do_average) {
load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -778,8 +773,7 @@
#else
s5 = vld1_s16(src_ptr);
- d0 = convolve6_4_s32(s0, s1, s2, s3, s4, s5, y_filter, round_shift_vec,
- offset_const);
+ d0 = convolve6_4_s32(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
if (do_average) {
dd0 = vld1_u16(dst_ptr);
@@ -828,14 +822,10 @@
#if defined(__aarch64__)
load_s16_8x4(s, src_stride, &s5, &s6, &s7, &s8);
- d0 = convolve6_8_s32(s0, s1, s2, s3, s4, s5, y_filter, round_shift_vec,
- offset_const);
- d1 = convolve6_8_s32(s1, s2, s3, s4, s5, s6, y_filter, round_shift_vec,
- offset_const);
- d2 = convolve6_8_s32(s2, s3, s4, s5, s6, s7, y_filter, round_shift_vec,
- offset_const);
- d3 = convolve6_8_s32(s3, s4, s5, s6, s7, s8, y_filter, round_shift_vec,
- offset_const);
+ d0 = convolve6_8_s32(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
+ d1 = convolve6_8_s32(s1, s2, s3, s4, s5, s6, y_filter, offset_const);
+ d2 = convolve6_8_s32(s2, s3, s4, s5, s6, s7, y_filter, offset_const);
+ d3 = convolve6_8_s32(s3, s4, s5, s6, s7, s8, y_filter, offset_const);
if (do_average) {
load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -862,8 +852,7 @@
#else
s5 = vld1q_s16(s);
- d0 = convolve6_8_s32(s0, s1, s2, s3, s4, s5, y_filter, round_shift_vec,
- offset_const);
+ d0 = convolve6_8_s32(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
if (do_average) {
dd0 = vld1q_u16(d);
@@ -904,14 +893,13 @@
const int dst_stride = conv_params->dst_stride;
const int bd = 8;
- const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
- const int16_t sub_const = (1 << (offset_bits - conv_params->round_1)) +
- (1 << (offset_bits - conv_params->round_1 - 1));
+ const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+ const int16_t sub_const = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+ (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
const int16_t round_bits =
- 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
- const int offset = bd + 2 * FILTER_BITS - conv_params->round_0;
- const int32x4_t round_shift_vec = vdupq_n_s32(-(conv_params->round_1));
+ 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS;
+ const int offset = bd + 2 * FILTER_BITS - ROUND0_BITS;
const int32x4_t offset_const = vdupq_n_s32(1 << offset);
const int16x4_t sub_const_vec = vdup_n_s16(sub_const);
const uint16_t fwd_offset = conv_params->fwd_offset;
@@ -938,13 +926,13 @@
load_s16_4x4(src_ptr, src_stride, &s7, &s8, &s9, &s10);
d0 = convolve8_4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
- round_shift_vec, offset_const);
+ offset_const);
d1 = convolve8_4_s32(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
- round_shift_vec, offset_const);
+ offset_const);
d2 = convolve8_4_s32(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
- round_shift_vec, offset_const);
+ offset_const);
d3 = convolve8_4_s32(s3, s4, s5, s6, s7, s8, s9, s10, y_filter,
- round_shift_vec, offset_const);
+ offset_const);
if (do_average) {
load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -976,7 +964,7 @@
s7 = vld1_s16(src_ptr);
d0 = convolve8_4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
- round_shift_vec, offset_const);
+ offset_const);
if (do_average) {
dd0 = vld1_u16(dst_ptr);
@@ -1028,13 +1016,13 @@
load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);
d0 = convolve8_8_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
- round_shift_vec, offset_const);
+ offset_const);
d1 = convolve8_8_s32(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
- round_shift_vec, offset_const);
+ offset_const);
d2 = convolve8_8_s32(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
- round_shift_vec, offset_const);
+ offset_const);
d3 = convolve8_8_s32(s3, s4, s5, s6, s7, s8, s9, s10, y_filter,
- round_shift_vec, offset_const);
+ offset_const);
if (do_average) {
load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -1064,7 +1052,7 @@
s7 = vld1q_s16(s);
d0 = convolve8_8_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
- round_shift_vec, offset_const);
+ offset_const);
if (do_average) {
dd0 = vld1q_u16(d);