Optimize Neon implementation of cdef filter functions
Optimize the constrain16() function, make the computation of max and min
more parallel by pairing taps into tree reductions, and simplify the
computation of the final result with a single rounding
shift-right-and-accumulate (vrsraq_n_s16).
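
For reference, the per-lane operation implemented by constrain16() can be
sketched in scalar C as follows (illustrative only; constrain_scalar is a
hypothetical name, not part of libaom):

  static int constrain_scalar(int a, int b, int threshold, int adjdamp) {
    const int diff = a > b ? a - b : b - a;  // |a - b|, cf. vabdq_u16
    int s = threshold - (diff >> adjdamp);
    if (s < 0) s = 0;                        // vqsubq_u16 saturates at zero
    const int clip = diff < s ? diff : s;    // min(|a - b|, s)
    return a > b ? clip : -clip;             // vbslq_s16 selects clip or -clip
  }

The final result now uses vrsraq_n_s16(row, sum, 4), which computes
row + ((sum + 8) >> 4) in one instruction; the preceding vcltq/vaddq pair
still folds in the "- (sum < 0)" correction, so the rounding matches the
original expression.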
Change-Id: I6e1a7993285e74165929dfc8185cb4f397f8478b
diff --git a/av1/common/arm/cdef_block_neon.c b/av1/common/arm/cdef_block_neon.c
index 68a292b..69ea49f 100644
--- a/av1/common/arm/cdef_block_neon.c
+++ b/av1/common/arm/cdef_block_neon.c
@@ -450,36 +450,37 @@
// sign(a-b) * min(abs(a-b), max(0, threshold - (abs(a-b) >> adjdamp)))
static INLINE int16x8_t constrain16(uint16x8_t a, uint16x8_t b,
unsigned int threshold, int adjdamp) {
- int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(a, b));
- const int16x8_t sign = vshrq_n_s16(diff, 15);
- diff = vabsq_s16(diff);
- const uint16x8_t s =
- vqsubq_u16(vdupq_n_u16(threshold),
- vreinterpretq_u16_s16(vshlq_s16(diff, vdupq_n_s16(-adjdamp))));
- return veorq_s16(vaddq_s16(sign, vminq_s16(diff, vreinterpretq_s16_u16(s))),
- sign);
+ uint16x8_t diff = vabdq_u16(a, b);
+ const uint16x8_t a_gt_b = vcgtq_u16(a, b);
+ const uint16x8_t s = vqsubq_u16(vdupq_n_u16(threshold),
+ vshlq_u16(diff, vdupq_n_s16(-adjdamp)));
+ const int16x8_t clip = vreinterpretq_s16_u16(vminq_u16(diff, s));
+ return vbslq_s16(a_gt_b, clip, vnegq_s16(clip));
}
static INLINE uint16x8_t get_max_primary(const int is_lowbd, uint16x8_t *tap,
uint16x8_t max,
uint16x8_t cdef_large_value_mask) {
if (is_lowbd) {
- uint8x16_t max_u8 = vreinterpretq_u8_u16(tap[0]);
- max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[1]));
- max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[2]));
- max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[3]));
- /* The source is 16 bits, however, we only really care about the lower
- 8 bits. The upper 8 bits contain the "large" flag. After the final
- primary max has been calculated, zero out the upper 8 bits. Use this
- to find the "16 bit" max. */
+ // The source is 16 bits, however, we only really care about the lower
+ // 8 bits. The upper 8 bits contain the "large" flag. After the final
+ // primary max has been calculated, zero out the upper 8 bits. Use this
+ // to find the "16 bit" max.
+ uint8x16_t max0 =
+ vmaxq_u8(vreinterpretq_u8_u16(tap[0]), vreinterpretq_u8_u16(tap[1]));
+ uint8x16_t max1 =
+ vmaxq_u8(vreinterpretq_u8_u16(tap[2]), vreinterpretq_u8_u16(tap[3]));
+ max0 = vmaxq_u8(max0, max1);
max = vmaxq_u16(
- max, vandq_u16(vreinterpretq_u16_u8(max_u8), cdef_large_value_mask));
+ max, vandq_u16(vreinterpretq_u16_u8(max0), cdef_large_value_mask));
} else {
- /* Convert CDEF_VERY_LARGE to 0 before calculating max. */
- max = vmaxq_u16(max, vandq_u16(tap[0], cdef_large_value_mask));
- max = vmaxq_u16(max, vandq_u16(tap[1], cdef_large_value_mask));
- max = vmaxq_u16(max, vandq_u16(tap[2], cdef_large_value_mask));
- max = vmaxq_u16(max, vandq_u16(tap[3], cdef_large_value_mask));
+ // Convert CDEF_VERY_LARGE to 0 before calculating max.
+ uint16x8_t max0 = vmaxq_u16(vandq_u16(tap[0], cdef_large_value_mask),
+ vandq_u16(tap[1], cdef_large_value_mask));
+ uint16x8_t max1 = vmaxq_u16(vandq_u16(tap[2], cdef_large_value_mask),
+ vandq_u16(tap[3], cdef_large_value_mask));
+ max0 = vmaxq_u16(max0, max1);
+ max = vmaxq_u16(max0, max);
}
return max;
}
@@ -488,30 +489,37 @@
uint16x8_t max,
uint16x8_t cdef_large_value_mask) {
if (is_lowbd) {
- uint8x16_t max_u8 = vreinterpretq_u8_u16(tap[0]);
- max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[1]));
- max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[2]));
- max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[3]));
- max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[4]));
- max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[5]));
- max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[6]));
- max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[7]));
- /* The source is 16 bits, however, we only really care about the lower
- 8 bits. The upper 8 bits contain the "large" flag. After the final
- primary max has been calculated, zero out the upper 8 bits. Use this
- to find the "16 bit" max. */
+ // The source is 16 bits, however, we only really care about the lower
+ // 8 bits. The upper 8 bits contain the "large" flag. After the final
+ // primary max has been calculated, zero out the upper 8 bits. Use this
+ // to find the "16 bit" max.
+ uint8x16_t max0 =
+ vmaxq_u8(vreinterpretq_u8_u16(tap[0]), vreinterpretq_u8_u16(tap[1]));
+ uint8x16_t max1 =
+ vmaxq_u8(vreinterpretq_u8_u16(tap[2]), vreinterpretq_u8_u16(tap[3]));
+ uint8x16_t max2 =
+ vmaxq_u8(vreinterpretq_u8_u16(tap[4]), vreinterpretq_u8_u16(tap[5]));
+ uint8x16_t max3 =
+ vmaxq_u8(vreinterpretq_u8_u16(tap[6]), vreinterpretq_u8_u16(tap[7]));
+ max0 = vmaxq_u8(max0, max1);
+ max2 = vmaxq_u8(max2, max3);
+ max0 = vmaxq_u8(max0, max2);
max = vmaxq_u16(
- max, vandq_u16(vreinterpretq_u16_u8(max_u8), cdef_large_value_mask));
+ max, vandq_u16(vreinterpretq_u16_u8(max0), cdef_large_value_mask));
} else {
- /* Convert CDEF_VERY_LARGE to 0 before calculating max. */
- max = vmaxq_u16(max, vandq_u16(tap[0], cdef_large_value_mask));
- max = vmaxq_u16(max, vandq_u16(tap[1], cdef_large_value_mask));
- max = vmaxq_u16(max, vandq_u16(tap[2], cdef_large_value_mask));
- max = vmaxq_u16(max, vandq_u16(tap[3], cdef_large_value_mask));
- max = vmaxq_u16(max, vandq_u16(tap[4], cdef_large_value_mask));
- max = vmaxq_u16(max, vandq_u16(tap[5], cdef_large_value_mask));
- max = vmaxq_u16(max, vandq_u16(tap[6], cdef_large_value_mask));
- max = vmaxq_u16(max, vandq_u16(tap[7], cdef_large_value_mask));
+ // Convert CDEF_VERY_LARGE to 0 before calculating max.
+ uint16x8_t max0 = vmaxq_u16(vandq_u16(tap[0], cdef_large_value_mask),
+ vandq_u16(tap[1], cdef_large_value_mask));
+ uint16x8_t max1 = vmaxq_u16(vandq_u16(tap[2], cdef_large_value_mask),
+ vandq_u16(tap[3], cdef_large_value_mask));
+ uint16x8_t max2 = vmaxq_u16(vandq_u16(tap[4], cdef_large_value_mask),
+ vandq_u16(tap[5], cdef_large_value_mask));
+ uint16x8_t max3 = vmaxq_u16(vandq_u16(tap[6], cdef_large_value_mask),
+ vandq_u16(tap[7], cdef_large_value_mask));
+ max0 = vmaxq_u16(max0, max1);
+ max2 = vmaxq_u16(max2, max3);
+ max0 = vmaxq_u16(max0, max2);
+ max = vmaxq_u16(max, max0);
}
return max;
}
@@ -576,10 +584,10 @@
if (clipping_required) {
max = get_max_primary(is_lowbd, tap, max, cdef_large_value_mask);
- min = vminq_u16(min, tap[0]);
- min = vminq_u16(min, tap[1]);
- min = vminq_u16(min, tap[2]);
- min = vminq_u16(min, tap[3]);
+ uint16x8_t min1 = vminq_u16(tap[0], tap[1]);
+ uint16x8_t min2 = vminq_u16(tap[2], tap[3]);
+ min1 = vminq_u16(min1, min2);
+ min = vminq_u16(min, min1);
}
}
@@ -621,22 +629,20 @@
if (clipping_required) {
max = get_max_secondary(is_lowbd, tap, max, cdef_large_value_mask);
- min = vminq_u16(min, tap[0]);
- min = vminq_u16(min, tap[1]);
- min = vminq_u16(min, tap[2]);
- min = vminq_u16(min, tap[3]);
- min = vminq_u16(min, tap[4]);
- min = vminq_u16(min, tap[5]);
- min = vminq_u16(min, tap[6]);
- min = vminq_u16(min, tap[7]);
+ uint16x8_t min0 = vminq_u16(tap[0], tap[1]);
+ uint16x8_t min1 = vminq_u16(tap[2], tap[3]);
+ uint16x8_t min2 = vminq_u16(tap[4], tap[5]);
+ uint16x8_t min3 = vminq_u16(tap[6], tap[7]);
+ min0 = vminq_u16(min0, min1);
+ min2 = vminq_u16(min2, min3);
+ min0 = vminq_u16(min0, min2);
+ min = vminq_u16(min, min0);
}
}
// res = row + ((sum - (sum < 0) + 8) >> 4)
sum = vaddq_s16(sum, vreinterpretq_s16_u16(vcltq_s16(sum, vdupq_n_s16(0))));
- int16x8_t res = vaddq_s16(sum, vdupq_n_s16(8));
- res = vshrq_n_s16(res, 4);
- res = vaddq_s16(vreinterpretq_s16_u16(s), res);
+ int16x8_t res = vrsraq_n_s16(vreinterpretq_s16_u16(s), sum, 4);
if (clipping_required) {
res = vminq_s16(vmaxq_s16(res, vreinterpretq_s16_u16(min)),
@@ -716,10 +722,10 @@
if (clipping_required) {
max = get_max_primary(is_lowbd, tap, max, cdef_large_value_mask);
- min = vminq_u16(min, tap[0]);
- min = vminq_u16(min, tap[1]);
- min = vminq_u16(min, tap[2]);
- min = vminq_u16(min, tap[3]);
+ uint16x8_t min0 = vminq_u16(tap[0], tap[1]);
+ uint16x8_t min1 = vminq_u16(tap[2], tap[3]);
+ min0 = vminq_u16(min0, min1);
+ min = vminq_u16(min, min0);
}
}
@@ -761,22 +767,21 @@
if (clipping_required) {
max = get_max_secondary(is_lowbd, tap, max, cdef_large_value_mask);
- min = vminq_u16(min, tap[0]);
- min = vminq_u16(min, tap[1]);
- min = vminq_u16(min, tap[2]);
- min = vminq_u16(min, tap[3]);
- min = vminq_u16(min, tap[4]);
- min = vminq_u16(min, tap[5]);
- min = vminq_u16(min, tap[6]);
- min = vminq_u16(min, tap[7]);
+ uint16x8_t min0 = vminq_u16(tap[0], tap[1]);
+ uint16x8_t min1 = vminq_u16(tap[2], tap[3]);
+ uint16x8_t min2 = vminq_u16(tap[4], tap[5]);
+ uint16x8_t min3 = vminq_u16(tap[6], tap[7]);
+ min0 = vminq_u16(min0, min1);
+ min2 = vminq_u16(min2, min3);
+ min0 = vminq_u16(min0, min2);
+ min = vminq_u16(min, min0);
}
}
// res = row + ((sum - (sum < 0) + 8) >> 4)
sum = vaddq_s16(sum, vreinterpretq_s16_u16(vcltq_s16(sum, vdupq_n_s16(0))));
- int16x8_t res = vaddq_s16(sum, vdupq_n_s16(8));
- res = vshrq_n_s16(res, 4);
- res = vaddq_s16(vreinterpretq_s16_u16(s), res);
+ int16x8_t res = vrsraq_n_s16(vreinterpretq_s16_u16(s), sum, 4);
+
if (clipping_required) {
res = vminq_s16(vmaxq_s16(res, vreinterpretq_s16_u16(min)),
vreinterpretq_s16_u16(max));