Specialize av1_dist_wtd_convolve_y_neon for 6-tap filter
av1_dist_wtd_convolve_y_neon can be called with 4-, 6- and 8-tap
filters, with 4- and 6-tap filters being padded out with 0s to enable
re-use of the same 8-tap code path. This is inefficient as we end up
spending a lot of time multiplying by, and adding, 0 - especially
since the most common filter size used with this function is 6-taps.
This patch adds an av1_dist_wtd_convolve_y_neon code path specialized
for 6-tap filters. This new path is used for both 4-tap and 6-tap
filters to reduce the amount of redundant work.
Change-Id: I730d6d1ff2e6316aedbc123c24bc26dc5a203dd1
diff --git a/aom_dsp/arm/mem_neon.h b/aom_dsp/arm/mem_neon.h
index 994a636..af11b1f 100644
--- a/aom_dsp/arm/mem_neon.h
+++ b/aom_dsp/arm/mem_neon.h
@@ -654,6 +654,14 @@
*tu1 = load_unaligned_u8_4x2(buf, stride);
}
+static INLINE void load_unaligned_u8_3x8(const uint8_t *buf, int stride,
+                                         uint8x8_t *tu0, uint8x8_t *tu1,
+                                         uint8x8_t *tu2) {  // 6 rows x 4 bytes into 3 vectors.
+  load_unaligned_u8_4x4(buf, stride, tu0, tu1);  // Rows 0-3 -> *tu0, *tu1.
+  buf += 4 * stride;
+  *tu2 = load_unaligned_u8_4x2(buf, stride);  // Rows 4-5 -> *tu2.
+}
+
static INLINE void load_unaligned_u8_4x8(const uint8_t *buf, int stride,
uint8x8_t *tu0, uint8x8_t *tu1,
uint8x8_t *tu2, uint8x8_t *tu3) {
diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index 8ee3203..f12bf07 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -1536,44 +1536,6 @@
#endif // defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
-static INLINE int16x4_t convolve6_4x4(const int16x4_t s0, const int16x4_t s1,
- const int16x4_t s2, const int16x4_t s3,
- const int16x4_t s4, const int16x4_t s5,
- const int16x8_t y_filter_0_7) {
- const int16x4_t y_filter_0_3 = vget_low_s16(y_filter_0_7);
- const int16x4_t y_filter_4_7 = vget_high_s16(y_filter_0_7);
- int16x4_t sum;
-
- // Filter values at indices 0 and 7 are 0.
- sum = vmul_lane_s16(s0, y_filter_0_3, 1);
- sum = vmla_lane_s16(sum, s1, y_filter_0_3, 2);
- sum = vmla_lane_s16(sum, s2, y_filter_0_3, 3);
- sum = vmla_lane_s16(sum, s3, y_filter_4_7, 0);
- sum = vmla_lane_s16(sum, s4, y_filter_4_7, 1);
- sum = vmla_lane_s16(sum, s5, y_filter_4_7, 2);
-
- return sum;
-}
-
-static INLINE int16x8_t convolve6_8x4(const int16x8_t s0, const int16x8_t s1,
- const int16x8_t s2, const int16x8_t s3,
- const int16x8_t s4, const int16x8_t s5,
- const int16x8_t y_filters) {
- const int16x4_t y_filter_lo = vget_low_s16(y_filters);
- const int16x4_t y_filter_hi = vget_high_s16(y_filters);
- int16x8_t sum;
-
- // Filter values at indices 0 and 7 are 0.
- sum = vmulq_lane_s16(s0, y_filter_lo, 1);
- sum = vmlaq_lane_s16(sum, s1, y_filter_lo, 2);
- sum = vmlaq_lane_s16(sum, s2, y_filter_lo, 3);
- sum = vmlaq_lane_s16(sum, s3, y_filter_hi, 0);
- sum = vmlaq_lane_s16(sum, s4, y_filter_hi, 1);
- sum = vmlaq_lane_s16(sum, s5, y_filter_hi, 2);
-
- return sum;
-}
-
static INLINE void convolve_y_sr_6tap_neon(const uint8_t *src_ptr,
int src_stride, uint8_t *dst_ptr,
const int dst_stride, int w, int h,
diff --git a/av1/common/arm/convolve_neon.h b/av1/common/arm/convolve_neon.h
index 3f10661..b8eac71 100644
--- a/av1/common/arm/convolve_neon.h
+++ b/av1/common/arm/convolve_neon.h
@@ -382,6 +382,44 @@
return sum;
}
+static INLINE int16x4_t convolve6_4x4(const int16x4_t s0, const int16x4_t s1,
+                                      const int16x4_t s2, const int16x4_t s3,
+                                      const int16x4_t s4, const int16x4_t s5,
+                                      const int16x8_t y_filter_0_7) {  // 6-tap dot product, 4 lanes.
+  const int16x4_t y_filter_0_3 = vget_low_s16(y_filter_0_7);
+  const int16x4_t y_filter_4_7 = vget_high_s16(y_filter_0_7);
+  int16x4_t sum;
+
+  // Filter values at indices 0 and 7 are 0.
+  sum = vmul_lane_s16(s0, y_filter_0_3, 1);       // s0 * filter[1]
+  sum = vmla_lane_s16(sum, s1, y_filter_0_3, 2);  // += s1 * filter[2]
+  sum = vmla_lane_s16(sum, s2, y_filter_0_3, 3);  // += s2 * filter[3]
+  sum = vmla_lane_s16(sum, s3, y_filter_4_7, 0);  // += s3 * filter[4]
+  sum = vmla_lane_s16(sum, s4, y_filter_4_7, 1);  // += s4 * filter[5]
+  sum = vmla_lane_s16(sum, s5, y_filter_4_7, 2);  // += s5 * filter[6]
+
+  return sum;
+}
+
+static INLINE int16x8_t convolve6_8x4(const int16x8_t s0, const int16x8_t s1,
+                                      const int16x8_t s2, const int16x8_t s3,
+                                      const int16x8_t s4, const int16x8_t s5,
+                                      const int16x8_t y_filters) {  // 6-tap dot product, 8 lanes.
+  const int16x4_t y_filter_lo = vget_low_s16(y_filters);
+  const int16x4_t y_filter_hi = vget_high_s16(y_filters);
+  int16x8_t sum;
+
+  // Filter values at indices 0 and 7 are 0.
+  sum = vmulq_lane_s16(s0, y_filter_lo, 1);       // s0 * filter[1]
+  sum = vmlaq_lane_s16(sum, s1, y_filter_lo, 2);  // += s1 * filter[2]
+  sum = vmlaq_lane_s16(sum, s2, y_filter_lo, 3);  // += s2 * filter[3]
+  sum = vmlaq_lane_s16(sum, s3, y_filter_hi, 0);  // += s3 * filter[4]
+  sum = vmlaq_lane_s16(sum, s4, y_filter_hi, 1);  // += s4 * filter[5]
+  sum = vmlaq_lane_s16(sum, s5, y_filter_hi, 2);  // += s5 * filter[6]
+
+  return sum;
+}
+
static INLINE uint16x4_t convolve6_4_s32(const int16x4_t s0, const int16x4_t s1,
const int16x4_t s2, const int16x4_t s3,
const int16x4_t s4, const int16x4_t s5,
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index 6c1e38e..ae437da 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -2053,6 +2053,306 @@
return vqrshlq_s16(sum, shift_round_0);
}
+static INLINE void dist_wtd_convolve_y_6tap_neon(
+    const uint8_t *src_ptr, int src_stride, uint8_t *dst8_ptr,
+    const int dst8_stride, int w, int h, const int16x8_t y_filter,
+    ConvolveParams *conv_params) {
+  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
+  const int dst_stride = conv_params->dst_stride;
+  const int bits = FILTER_BITS - conv_params->round_0;
+  const int bd = 8;
+  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
+  const int round_offset = (1 << (offset_bits - conv_params->round_1)) +
+                           (1 << (offset_bits - conv_params->round_1 - 1));
+  const int round_bits =
+      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+  const uint16_t fwd_offset = conv_params->fwd_offset;
+  const uint16_t bck_offset = conv_params->bck_offset;
+  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
+  const int shift_value = (conv_params->round_1 - 1 - bits);
+
+  // A single rounding shift by shift_value replaces multiplying the vertical
+  // filter output sum by (1 << bits); valid only if (round_1 - 2) >= bits.
+  assert((conv_params->round_1 - 2) >= bits);
+
+  if (w <= 4 || h == 4) {  // Narrow path: process 4 columns at a time.
+    int16x4_t s0, s1, s2, s3, s4, s5, d0;
+    uint16x4_t dd0;
+    uint8x8_t t0 = vdup_n_u8(0);
+    uint8x8_t t1 = vdup_n_u8(0);
+    uint8x8_t t2 = vdup_n_u8(0);
+    int16x8_t tt0, tt1, tt2;
+    uint8x8_t d01;
+#if defined(__aarch64__)
+    int16x4_t s6, s7, s8, d1, d2, d3;
+    uint16x4_t dd1, dd2, dd3;
+    uint8x8_t d23;
+#endif
+
+    const int16x4_t round_offset64 = vdup_n_s16(round_offset);
+    const int16x4_t shift_vec = vdup_n_s16(-shift_value);
+    int width = w;
+
+    do {
+      const uint8_t *s = src_ptr;
+      CONV_BUF_TYPE *d = dst_ptr;
+      uint8_t *d_u8 = dst8_ptr;
+      int height = h;
+
+      load_unaligned_u8_3x8(s, src_stride, &t0, &t1, &t2);
+
+      tt0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+      tt1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+      tt2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+
+      s0 = vget_low_s16(tt0);
+      s1 = vget_high_s16(tt0);
+      s2 = vget_low_s16(tt1);
+      s3 = vget_high_s16(tt1);
+      s4 = vget_low_s16(tt2);
+
+      s += 5 * src_stride;  // Rows 0-4 are kept; the loop reloads from row 5.
+      do {
+#if defined(__aarch64__)
+        load_unaligned_u8_4x4(s, src_stride, &t0, &t1);
+
+        tt0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+        tt1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+
+        s5 = vget_low_s16(tt0);
+        s6 = vget_high_s16(tt0);
+        s7 = vget_low_s16(tt1);
+        s8 = vget_high_s16(tt1);
+
+        d0 = convolve6_4x4(s0, s1, s2, s3, s4, s5, y_filter);
+        d1 = convolve6_4x4(s1, s2, s3, s4, s5, s6, y_filter);
+        d2 = convolve6_4x4(s2, s3, s4, s5, s6, s7, y_filter);
+        d3 = convolve6_4x4(s3, s4, s5, s6, s7, s8, y_filter);
+
+        d0 = vqrshl_s16(d0, shift_vec);
+        d1 = vqrshl_s16(d1, shift_vec);
+        d2 = vqrshl_s16(d2, shift_vec);
+        d3 = vqrshl_s16(d3, shift_vec);
+
+        d0 = vadd_s16(d0, round_offset64);
+        d1 = vadd_s16(d1, round_offset64);
+        d2 = vadd_s16(d2, round_offset64);
+        d3 = vadd_s16(d3, round_offset64);
+
+        if (conv_params->do_average) {
+          load_u16_4x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
+
+          compute_avg_4x4(dd0, dd1, dd2, dd3, vreinterpret_u16_s16(d0),
+                          vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
+                          vreinterpret_u16_s16(d3), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &d01, &d23);
+
+          store_u8_4x1(d_u8 + 0 * dst8_stride, d01, 0);
+          store_u8_4x1(d_u8 + 1 * dst8_stride, d01, 1);
+          store_u8_4x1(d_u8 + 2 * dst8_stride, d23, 0);
+          store_u8_4x1(d_u8 + 3 * dst8_stride, d23, 1);
+        } else {
+          store_u16_4x4(d, dst_stride, vreinterpret_u16_s16(d0),
+                        vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
+                        vreinterpret_u16_s16(d3));
+        }
+
+        s0 = s4;  // Slide the 6-sample window down four rows.
+        s1 = s5;
+        s2 = s6;
+        s3 = s7;
+        s4 = s8;
+        s += 4 * src_stride;
+        d += 4 * dst_stride;
+        d_u8 += 4 * dst8_stride;
+        height -= 4;
+#else   // !defined(__aarch64__)
+        t0 = load_unaligned_u8_4x1(s);
+        tt0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+        s5 = vget_low_s16(tt0);
+
+        d0 = convolve6_4x4(s0, s1, s2, s3, s4, s5, y_filter);
+        d0 = vqrshl_s16(d0, shift_vec);
+        d0 = vadd_s16(d0, round_offset64);
+
+        if (conv_params->do_average) {
+          dd0 = vld1_u16(d);
+
+          compute_avg_4x1(dd0, vreinterpret_u16_s16(d0), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &d01);
+
+          store_u8_4x1(d_u8, d01, 0);
+        } else {
+          vst1_u16(d, vreinterpret_u16_s16(d0));
+        }
+
+        s0 = s1;  // Slide the 6-sample window down one row.
+        s1 = s2;
+        s2 = s3;
+        s3 = s4;
+        s4 = s5;
+        s += src_stride;
+        d += dst_stride;
+        d_u8 += dst8_stride;
+        height--;
+#endif  // defined(__aarch64__)
+      } while (height > 0);
+      src_ptr += 4;
+      dst_ptr += 4;
+      dst8_ptr += 4;
+      width -= 4;
+    } while (width > 0);
+  } else {  // Wide path: process 8 columns at a time.
+    int16x8_t s0, s1, s2, s3, s4, s5, d0;
+    uint16x8_t d8;
+    uint8x8_t t0, t1, t2, t3, t4;
+
+    const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
+    const int16x8_t shift_vec = vdupq_n_s16(-shift_value);
+    const int16x4_t round_offset64 = vdup_n_s16(round_offset);
+#if defined(__aarch64__)
+    int16x8_t s6, s7, s8, s9, s10, s11, s12, d1, d2, d3, d4, d5, d6, d7;
+    uint16x8_t d9, d10, d11;
+    uint8x8_t t5, t6, t7;
+#endif
+    int width = w;
+
+    do {
+      const uint8_t *s = src_ptr + (5 * src_stride);
+      CONV_BUF_TYPE *d = dst_ptr;
+      uint8_t *d_u8 = dst8_ptr;
+      int height = h;
+
+      load_u8_8x5(src_ptr, src_stride, &t0, &t1, &t2, &t3, &t4);
+
+      s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+      s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+      s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+      s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
+      s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
+
+      do {
+#if defined(__aarch64__)
+        load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
+
+        s5 = vreinterpretq_s16_u16(vmovl_u8(t0));
+        s6 = vreinterpretq_s16_u16(vmovl_u8(t1));
+        s7 = vreinterpretq_s16_u16(vmovl_u8(t2));
+        s8 = vreinterpretq_s16_u16(vmovl_u8(t3));
+        s9 = vreinterpretq_s16_u16(vmovl_u8(t4));
+        s10 = vreinterpretq_s16_u16(vmovl_u8(t5));
+        s11 = vreinterpretq_s16_u16(vmovl_u8(t6));
+        s12 = vreinterpretq_s16_u16(vmovl_u8(t7));
+
+        d0 = convolve6_8x4(s0, s1, s2, s3, s4, s5, y_filter);
+        d1 = convolve6_8x4(s1, s2, s3, s4, s5, s6, y_filter);
+        d2 = convolve6_8x4(s2, s3, s4, s5, s6, s7, y_filter);
+        d3 = convolve6_8x4(s3, s4, s5, s6, s7, s8, y_filter);
+        d4 = convolve6_8x4(s4, s5, s6, s7, s8, s9, y_filter);
+        d5 = convolve6_8x4(s5, s6, s7, s8, s9, s10, y_filter);
+        d6 = convolve6_8x4(s6, s7, s8, s9, s10, s11, y_filter);
+        d7 = convolve6_8x4(s7, s8, s9, s10, s11, s12, y_filter);
+
+        d0 = vqrshlq_s16(d0, shift_vec);
+        d1 = vqrshlq_s16(d1, shift_vec);
+        d2 = vqrshlq_s16(d2, shift_vec);
+        d3 = vqrshlq_s16(d3, shift_vec);
+        d4 = vqrshlq_s16(d4, shift_vec);
+        d5 = vqrshlq_s16(d5, shift_vec);
+        d6 = vqrshlq_s16(d6, shift_vec);
+        d7 = vqrshlq_s16(d7, shift_vec);
+
+        d0 = vaddq_s16(d0, round_offset128);
+        d1 = vaddq_s16(d1, round_offset128);
+        d2 = vaddq_s16(d2, round_offset128);
+        d3 = vaddq_s16(d3, round_offset128);
+        d4 = vaddq_s16(d4, round_offset128);
+        d5 = vaddq_s16(d5, round_offset128);
+        d6 = vaddq_s16(d6, round_offset128);
+        d7 = vaddq_s16(d7, round_offset128);
+
+        if (conv_params->do_average) {
+          load_u16_8x4(d, dst_stride, &d8, &d9, &d10, &d11);
+          d += 4 * dst_stride;
+
+          compute_avg_8x4(d8, d9, d10, d11, vreinterpretq_u16_s16(d0),
+                          vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
+                          vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &t0, &t1, &t2, &t3);
+
+          store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
+          d_u8 += 4 * dst8_stride;
+
+          load_u16_8x4(d, dst_stride, &d8, &d9, &d10, &d11);
+          d += 4 * dst_stride;
+
+          compute_avg_8x4(d8, d9, d10, d11, vreinterpretq_u16_s16(d4),
+                          vreinterpretq_u16_s16(d5), vreinterpretq_u16_s16(d6),
+                          vreinterpretq_u16_s16(d7), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &t0, &t1, &t2, &t3);
+
+          store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
+          d_u8 += 4 * dst8_stride;
+        } else {
+          store_u16_8x8(d, dst_stride, vreinterpretq_u16_s16(d0),
+                        vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
+                        vreinterpretq_u16_s16(d3), vreinterpretq_u16_s16(d4),
+                        vreinterpretq_u16_s16(d5), vreinterpretq_u16_s16(d6),
+                        vreinterpretq_u16_s16(d7));
+          d += 8 * dst_stride;
+        }
+
+        s0 = s8;  // Slide the 6-sample window down eight rows.
+        s1 = s9;
+        s2 = s10;
+        s3 = s11;
+        s4 = s12;
+        s += 8 * src_stride;
+        height -= 8;
+#else   // !defined(__aarch64__)
+        s5 = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s)));
+
+        d0 = convolve6_8x4(s0, s1, s2, s3, s4, s5, y_filter);
+        d0 = vqrshlq_s16(d0, shift_vec);
+        d0 = vaddq_s16(d0, round_offset128);
+
+        s0 = s1;  // Slide the 6-sample window down one row.
+        s1 = s2;
+        s2 = s3;
+        s3 = s4;
+        s4 = s5;
+
+        if (conv_params->do_average) {
+          d8 = vld1q_u16(d);
+          d += dst_stride;
+
+          compute_avg_8x1(d8, vreinterpretq_u16_s16(d0), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &t0);
+
+          vst1_u8(d_u8, t0);
+          d_u8 += dst8_stride;
+        } else {
+          vst1q_u16(d, vreinterpretq_u16_s16(d0));
+          d += dst_stride;
+        }
+
+        s += src_stride;
+        height--;
+#endif  // defined(__aarch64__)
+      } while (height > 0);
+      src_ptr += 8;
+      dst_ptr += 8;
+      dst8_ptr += 8;
+      width -= 8;
+    } while (width > 0);
+  }
+}
+
void av1_dist_wtd_convolve_y_neon(const uint8_t *src, int src_stride,
uint8_t *dst8, int dst8_stride, int w, int h,
const InterpFilterParams *filter_params_y,
@@ -2071,6 +2371,12 @@
const int vert_offset = filter_params_y->taps / 2 - 1;
const uint8_t *src_ptr = src - (vert_offset * src_stride);
+  if (get_filter_tap(filter_params_y, subpel_y_qn) <= 6) {  // 4-tap is zero-padded to 6 taps.
+    dist_wtd_convolve_y_6tap_neon(src_ptr + src_stride, src_stride, dst8,  // +1 row: 6-tap needs 2 rows above, vert_offset gave 3.
+                                  dst8_stride, w, h, y_filter, conv_params);
+    return;
+  }
+
CONV_BUF_TYPE *dst_ptr = conv_params->dst;
const int dst_stride = conv_params->dst_stride;
const int bits = FILTER_BITS - conv_params->round_0;