Propagate constants in av1_dist_wtd_convolve_2d_copy_neon
Rounding modes and other convolution parameters are known ahead of
time. This patch propagates the values into the Neon code paths for
av1_dist_wtd_convolve_2d_copy_neon - enabling us to make some useful
simplifications and optimizations.
Co-authored-by: Jonathan Wright <jonathan.wright@arm.com>
Change-Id: I8677d9c5dfae7eb3c9a6bc77d90bdf280db33291
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index 6dd4ecb..f970044 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -1135,16 +1135,14 @@
int h, ConvolveParams *conv_params) {
CONV_BUF_TYPE *dst = conv_params->dst;
const int dst_stride = conv_params->dst_stride;
- const int16_t bits =
- FILTER_BITS * 2 - conv_params->round_1 - conv_params->round_0;
const int bd = 8;
- const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
- const int round_offset = (1 << (offset_bits - conv_params->round_1)) +
- (1 << (offset_bits - conv_params->round_1 - 1));
+ const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+ const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+ (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
const int16x4_t sub_const_vec = vdup_n_s16((int16_t)round_offset);
+ const uint8x8_t shift_by_bits = vdup_n_u8(1 << (FILTER_BITS - ROUND0_BITS));
if (w >= 8) {
- const int16x8_t shift_by_bits = vdupq_n_s16(bits);
const uint16x8_t round_offset_vec = vdupq_n_u16((uint16_t)round_offset);
uint8x8_t s0, s1, s2, s3, dd0, dd1, dd2, dd3;
uint16x8_t d0, d1, d2, d3, t0, t1, t2, t3;
@@ -1159,22 +1157,17 @@
do {
load_u8_8x4(s, src_stride, &s0, &s1, &s2, &s3);
- d0 = vshlq_u16(vmovl_u8(s0), shift_by_bits);
- d1 = vshlq_u16(vmovl_u8(s1), shift_by_bits);
- d2 = vshlq_u16(vmovl_u8(s2), shift_by_bits);
- d3 = vshlq_u16(vmovl_u8(s3), shift_by_bits);
-
- d0 = vaddq_u16(d0, round_offset_vec);
- d1 = vaddq_u16(d1, round_offset_vec);
- d2 = vaddq_u16(d2, round_offset_vec);
- d3 = vaddq_u16(d3, round_offset_vec);
+ d0 = vmlal_u8(round_offset_vec, s0, shift_by_bits);
+ d1 = vmlal_u8(round_offset_vec, s1, shift_by_bits);
+ d2 = vmlal_u8(round_offset_vec, s2, shift_by_bits);
+ d3 = vmlal_u8(round_offset_vec, s3, shift_by_bits);
if (conv_params->do_average) {
load_u16_8x4(d_conv_buf, dst_stride, &t0, &t1, &t2, &t3);
compute_avg_8x4(
t0, t1, t2, t3, d0, d1, d2, d3, conv_params->fwd_offset,
- conv_params->bck_offset, sub_const_vec, bits,
+ conv_params->bck_offset, sub_const_vec, FILTER_BITS - ROUND0_BITS,
conv_params->use_dist_wtd_comp_avg, &dd0, &dd1, &dd2, &dd3);
store_u8_8x4(d_u8, dst8_stride, dd0, dd1, dd2, dd3);
@@ -1191,8 +1184,7 @@
dst8 += 4 * dst8_stride;
} while (--height != 0);
} else {
- const int16x4_t shift_by_bits = vdup_n_s16(bits);
- const uint16x4_t round_offset_vec = vdup_n_u16((uint16_t)round_offset);
+ const uint16x8_t round_offset_vec = vdupq_n_u16((uint16_t)round_offset);
uint8x8_t s0, s1, s2, s3, d01, d23;
uint16x4_t d0, d1, d2, d3, t0, t1, t2, t3;
int height = h / 4;
@@ -1200,21 +1192,17 @@
do {
load_u8_8x4(src, src_stride, &s0, &s1, &s2, &s3);
- d0 = vshl_u16(vget_low_u16(vmovl_u8(s0)), shift_by_bits);
- d1 = vshl_u16(vget_low_u16(vmovl_u8(s1)), shift_by_bits);
- d2 = vshl_u16(vget_low_u16(vmovl_u8(s2)), shift_by_bits);
- d3 = vshl_u16(vget_low_u16(vmovl_u8(s3)), shift_by_bits);
-
- d0 = vadd_u16(d0, round_offset_vec);
- d1 = vadd_u16(d1, round_offset_vec);
- d2 = vadd_u16(d2, round_offset_vec);
- d3 = vadd_u16(d3, round_offset_vec);
+ d0 = vget_low_u16(vmlal_u8(round_offset_vec, s0, shift_by_bits));
+ d1 = vget_low_u16(vmlal_u8(round_offset_vec, s1, shift_by_bits));
+ d2 = vget_low_u16(vmlal_u8(round_offset_vec, s2, shift_by_bits));
+ d3 = vget_low_u16(vmlal_u8(round_offset_vec, s3, shift_by_bits));
if (conv_params->do_average) {
load_u16_4x4(dst, dst_stride, &t0, &t1, &t2, &t3);
compute_avg_4x4(t0, t1, t2, t3, d0, d1, d2, d3, conv_params->fwd_offset,
- conv_params->bck_offset, sub_const_vec, bits,
+ conv_params->bck_offset, sub_const_vec,
+ FILTER_BITS - ROUND0_BITS,
conv_params->use_dist_wtd_comp_avg, &d01, &d23);
store_u8_4x1(dst8 + 0 * dst8_stride, d01, 0);