Optimize av1_dr_prediction_z2 Neon implementation
Optimize av1_dr_prediction_z2_neon by cleaning up and simplifying the
shift and interpolation calculations.
This optimization is a port from SVT-AV1.
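For reference, the per-lane interpolation is unchanged; only the grouping
of terms moves. Below is a minimal scalar sketch of the identity this
relies on (illustrative helper names, not part of the patch), with the
+16 rounding now supplied by the rounding narrowing shift (vrshrn_n_u16)
instead of being pre-added to the accumulator:

  #include <assert.h>
  #include <stdint.h>

  // Old grouping: seed the accumulator with the +16 rounding term and
  // use a truncating shift, i.e. vmlal_u8(vdupq_n_u16(16), ...) + vshrn.
  static uint8_t interp_old(uint8_t a0, uint8_t a1, int shift) {
    return (uint8_t)((a0 * 32 + 16 + (a1 - a0) * shift) >> 5);
  }

  // New grouping: a0 * (32 - shift) + a1 * shift
  //             = (a1 - a0) * shift + a0 * 32,
  // with the +16 and >> 5 performed by vrshrn_n_u16(res, 5).
  static uint8_t interp_new(uint8_t a0, uint8_t a1, int shift) {
    int res = (a1 - a0) * shift + a0 * 32;
    return (uint8_t)((res + 16) >> 5);
  }

  int main(void) {
    for (int a0 = 0; a0 < 256; ++a0)
      for (int a1 = 0; a1 < 256; ++a1)
        for (int shift = 0; shift < 32; ++shift)
          assert(interp_old((uint8_t)a0, (uint8_t)a1, shift) ==
                 interp_new((uint8_t)a0, (uint8_t)a1, shift));
    return 0;
  }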
Change-Id: I38464487ff0fe4aeee2575ca9851944fc133b4e8
diff --git a/aom_dsp/arm/intrapred_neon.c b/aom_dsp/arm/intrapred_neon.c
index 6ea58b9..be63638 100644
--- a/aom_dsp/arm/intrapred_neon.c
+++ b/aom_dsp/arm/intrapred_neon.c
@@ -1497,19 +1497,19 @@
static AOM_FORCE_INLINE void dr_prediction_z2_4xH_above_neon(
const uint8_t *above, int upsample_above, int dx, int base_x, int y,
uint8x8_t *a0_x, uint8x8_t *a1_x, uint16x4_t *shift0) {
- uint16x4_t r6 = vcreate_u16(0x00C0008000400000);
- uint16x4_t ydx = vdup_n_u16(y * dx);
+ const int x = -y * dx;
+
if (upsample_above) {
// Cannot use LD2 here since we only want to load eight bytes, but LD2 can
// only load either 16 or 32.
uint8x8_t v_tmp = vld1_u8(above + base_x);
*a0_x = vuzp_u8(v_tmp, vdup_n_u8(0)).val[0];
*a1_x = vuzp_u8(v_tmp, vdup_n_u8(0)).val[1];
- *shift0 = vand_u16(vsub_u16(r6, ydx), vdup_n_u16(0x1f));
+ *shift0 = vdup_n_u16(x & 0x1f); // ((x << upsample_above) & 0x3f) >> 1
} else {
*a0_x = load_unaligned_u8_4x1(above + base_x);
*a1_x = load_unaligned_u8_4x1(above + base_x + 1);
- *shift0 = vand_u16(vhsub_u16(r6, ydx), vdup_n_u16(0x1f));
+ *shift0 = vdup_n_u16((x & 0x3f) >> 1);
}
}
@@ -1519,19 +1519,15 @@
#else
const uint8_t *left,
#endif
- int upsample_left, int dy, int r, int min_base_y, int frac_bits_y,
- uint16x4_t *a0_y, uint16x4_t *a1_y, uint16x4_t *shift1) {
+ int upsample_left, int dy, int r, int frac_bits_y, uint16x4_t *a0_y,
+ uint16x4_t *a1_y, uint16x4_t *shift1) {
int16x4_t dy64 = vdup_n_s16(dy);
int16x4_t v_1234 = vcreate_s16(0x0004000300020001);
int16x4_t v_frac_bits_y = vdup_n_s16(-frac_bits_y);
- int16x4_t min_base_y64 = vdup_n_s16(min_base_y);
int16x4_t v_r6 = vdup_n_s16(r << 6);
int16x4_t y_c64 = vmls_s16(v_r6, v_1234, dy64);
int16x4_t base_y_c64 = vshl_s16(y_c64, v_frac_bits_y);
- // Values in base_y_c64 range from -2 through 14 inclusive.
- base_y_c64 = vmax_s16(base_y_c64, min_base_y64);
-
#if AOM_ARCH_AARCH64
uint8x8_t left_idx0 =
vreinterpret_u8_s16(vadd_s16(base_y_c64, vdup_n_s16(2))); // [0, 16]
@@ -1563,7 +1559,8 @@
#endif // AOM_ARCH_AARCH64
if (upsample_left) {
- *shift1 = vand_u16(vreinterpret_u16_s16(y_c64), vdup_n_u16(0x1f));
+ *shift1 = vand_u16(vreinterpret_u16_s16(y_c64),
+ vdup_n_u16(0x1f)); // ((y << upsample_left) & 0x3f) >> 1
} else {
*shift1 =
vand_u16(vshr_n_u16(vreinterpret_u16_s16(y_c64), 1), vdup_n_u16(0x1f));
@@ -1572,10 +1569,7 @@
static AOM_FORCE_INLINE uint8x8_t dr_prediction_z2_8xH_above_neon(
const uint8_t *above, int upsample_above, int dx, int base_x, int y) {
- uint16x8_t c1234 = vcombine_u16(vcreate_u16(0x0004000300020001),
- vcreate_u16(0x0008000700060005));
- uint16x8_t ydx = vdupq_n_u16(y * dx);
- uint16x8_t r6 = vshlq_n_u16(vextq_u16(c1234, vdupq_n_u16(0), 2), 6);
+ const int x = -y * dx;
uint16x8_t shift0;
uint8x8_t a0_x0;
@@ -1584,18 +1578,18 @@
uint8x8x2_t v_tmp = vld2_u8(above + base_x);
a0_x0 = v_tmp.val[0];
a1_x0 = v_tmp.val[1];
- shift0 = vandq_u16(vsubq_u16(r6, ydx), vdupq_n_u16(0x1f));
+ shift0 = vdupq_n_u16(x & 0x1f); // ((x << upsample_above) & 0x3f) >> 1
} else {
a0_x0 = vld1_u8(above + base_x);
a1_x0 = vld1_u8(above + base_x + 1);
- shift0 = vandq_u16(vhsubq_u16(r6, ydx), vdupq_n_u16(0x1f));
+ shift0 = vdupq_n_u16((x & 0x3f) >> 1);
}
uint16x8_t diff0 = vsubl_u8(a1_x0, a0_x0); // a[x+1] - a[x]
- uint16x8_t a32 =
- vmlal_u8(vdupq_n_u16(16), a0_x0, vdup_n_u8(32)); // a[x] * 32 + 16
+ uint16x8_t a32 = vshll_n_u8(a0_x0, 5); // a[x] * 32
uint16x8_t res = vmlaq_u16(a32, diff0, shift0);
- return vshrn_n_u16(res, 5);
+
+ return vrshrn_n_u16(res, 5);
}
static AOM_FORCE_INLINE uint8x8_t dr_prediction_z2_8xH_left_neon(
@@ -1604,20 +1598,16 @@
#else
const uint8_t *left,
#endif
- int upsample_left, int dy, int r, int min_base_y, int frac_bits_y) {
+ int upsample_left, int dy, int r, int frac_bits_y) {
int16x8_t v_r6 = vdupq_n_s16(r << 6);
int16x8_t dy128 = vdupq_n_s16(dy);
int16x8_t v_frac_bits_y = vdupq_n_s16(-frac_bits_y);
- int16x8_t min_base_y128 = vdupq_n_s16(min_base_y);
uint16x8_t c1234 = vcombine_u16(vcreate_u16(0x0004000300020001),
vcreate_u16(0x0008000700060005));
int16x8_t y_c128 = vmlsq_s16(v_r6, vreinterpretq_s16_u16(c1234), dy128);
int16x8_t base_y_c128 = vshlq_s16(y_c128, v_frac_bits_y);
- // Values in base_y_c128 range from -2 through 31 inclusive.
- base_y_c128 = vmaxq_s16(base_y_c128, min_base_y128);
-
#if AOM_ARCH_AARCH64
uint8x16_t left_idx0 =
vreinterpretq_u8_s16(vaddq_s16(base_y_c128, vdupq_n_s16(2))); // [0, 33]
@@ -1635,46 +1625,39 @@
uint16x8_t shift1;
if (upsample_left) {
- shift1 = vandq_u16(vreinterpretq_u16_s16(y_c128), vdupq_n_u16(0x1f));
+ shift1 =
+ vandq_u16(vreinterpretq_u16_s16(y_c128),
+ vdupq_n_u16(0x1f)); // ((y << upsample_left) & 0x3f) >> 1
} else {
shift1 = vshrq_n_u16(
vandq_u16(vreinterpretq_u16_s16(y_c128), vdupq_n_u16(0x3f)), 1);
}
uint16x8_t diff1 = vsubl_u8(a1_x1, a0_x1);
- uint16x8_t a32 = vmlal_u8(vdupq_n_u16(16), a0_x1, vdup_n_u8(32));
+ uint16x8_t a32 = vshll_n_u8(a0_x1, 5);
uint16x8_t res = vmlaq_u16(a32, diff1, shift1);
- return vshrn_n_u16(res, 5);
+
+ return vrshrn_n_u16(res, 5);
}
static AOM_FORCE_INLINE uint8x16_t dr_prediction_z2_WxH_above_neon(
const uint8_t *above, int dx, int base_x, int y, int j) {
- uint16x8x2_t c0123 = { { vcombine_u16(vcreate_u16(0x0003000200010000),
- vcreate_u16(0x0007000600050004)),
- vcombine_u16(vcreate_u16(0x000B000A00090008),
- vcreate_u16(0x000F000E000D000C)) } };
- uint16x8_t j256 = vdupq_n_u16(j);
- uint16x8_t ydx = vdupq_n_u16((uint16_t)(y * dx));
-
+ const int x = -y * dx;
const uint8x16_t a0_x128 = vld1q_u8(above + base_x + j);
const uint8x16_t a1_x128 = vld1q_u8(above + base_x + j + 1);
- uint16x8_t res6_0 = vshlq_n_u16(vaddq_u16(c0123.val[0], j256), 6);
- uint16x8_t res6_1 = vshlq_n_u16(vaddq_u16(c0123.val[1], j256), 6);
- uint16x8_t shift0 =
- vshrq_n_u16(vandq_u16(vsubq_u16(res6_0, ydx), vdupq_n_u16(0x3f)), 1);
- uint16x8_t shift1 =
- vshrq_n_u16(vandq_u16(vsubq_u16(res6_1, ydx), vdupq_n_u16(0x3f)), 1);
+ const uint16x8_t shift = vdupq_n_u16((x & 0x3f) >> 1);
+
// a[x+1] - a[x]
uint16x8_t diff0 = vsubl_u8(vget_low_u8(a1_x128), vget_low_u8(a0_x128));
uint16x8_t diff1 = vsubl_u8(vget_high_u8(a1_x128), vget_high_u8(a0_x128));
- // a[x] * 32 + 16
- uint16x8_t a32_0 =
- vmlal_u8(vdupq_n_u16(16), vget_low_u8(a0_x128), vdup_n_u8(32));
- uint16x8_t a32_1 =
- vmlal_u8(vdupq_n_u16(16), vget_high_u8(a0_x128), vdup_n_u8(32));
- uint16x8_t res0 = vmlaq_u16(a32_0, diff0, shift0);
- uint16x8_t res1 = vmlaq_u16(a32_1, diff1, shift1);
- return vcombine_u8(vshrn_n_u16(res0, 5), vshrn_n_u16(res1, 5));
+
+ // a[x] * 32
+ uint16x8_t a32_0 = vshll_n_u8(vget_low_u8(a0_x128), 5);
+ uint16x8_t a32_1 = vshll_n_u8(vget_high_u8(a0_x128), 5);
+ uint16x8_t res0 = vmlaq_u16(a32_0, diff0, shift);
+ uint16x8_t res1 = vmlaq_u16(a32_1, diff1, shift);
+
+ return vcombine_u8(vrshrn_n_u16(res0, 5), vrshrn_n_u16(res1, 5));
}
static AOM_FORCE_INLINE uint8x16_t dr_prediction_z2_WxH_left_neon(
@@ -1686,39 +1669,25 @@
int dy, int r, int j) {
// here upsample_above and upsample_left are 0 by design of
// av1_use_intra_edge_upsample
- const int min_base_y = -1;
-
- int16x8_t min_base_y256 = vdupq_n_s16(min_base_y);
- int16x8_t half_min_base_y256 = vdupq_n_s16(min_base_y >> 1);
int16x8_t dy256 = vdupq_n_s16(dy);
uint16x8_t j256 = vdupq_n_u16(j);
- uint16x8x2_t c0123 = { { vcombine_u16(vcreate_u16(0x0003000200010000),
- vcreate_u16(0x0007000600050004)),
- vcombine_u16(vcreate_u16(0x000B000A00090008),
- vcreate_u16(0x000F000E000D000C)) } };
- uint16x8x2_t c1234 = { { vaddq_u16(c0123.val[0], vdupq_n_u16(1)),
- vaddq_u16(c0123.val[1], vdupq_n_u16(1)) } };
+ uint16x8x2_t c1234 = { { vcombine_u16(vcreate_u16(0x0004000300020001),
+ vcreate_u16(0x0008000700060005)),
+ vcombine_u16(vcreate_u16(0x000C000B000A0009),
+ vcreate_u16(0x0010000F000E000D)) } };
int16x8_t v_r6 = vdupq_n_s16(r << 6);
int16x8_t c256_0 = vreinterpretq_s16_u16(vaddq_u16(j256, c1234.val[0]));
int16x8_t c256_1 = vreinterpretq_s16_u16(vaddq_u16(j256, c1234.val[1]));
- int16x8_t mul16_lo = vreinterpretq_s16_u16(
- vminq_u16(vreinterpretq_u16_s16(vmulq_s16(c256_0, dy256)),
- vreinterpretq_u16_s16(half_min_base_y256)));
- int16x8_t mul16_hi = vreinterpretq_s16_u16(
- vminq_u16(vreinterpretq_u16_s16(vmulq_s16(c256_1, dy256)),
- vreinterpretq_u16_s16(half_min_base_y256)));
- int16x8_t y_c256_lo = vsubq_s16(v_r6, mul16_lo);
- int16x8_t y_c256_hi = vsubq_s16(v_r6, mul16_hi);
+
+ int16x8_t y_c256_lo = vmlsq_s16(v_r6, c256_0, dy256);
+ int16x8_t y_c256_hi = vmlsq_s16(v_r6, c256_1, dy256);
int16x8_t base_y_c256_lo = vshrq_n_s16(y_c256_lo, 6);
int16x8_t base_y_c256_hi = vshrq_n_s16(y_c256_hi, 6);
- base_y_c256_lo = vmaxq_s16(min_base_y256, base_y_c256_lo);
- base_y_c256_hi = vmaxq_s16(min_base_y256, base_y_c256_hi);
-
#if !AOM_ARCH_AARCH64
int16_t min_y = vgetq_lane_s16(base_y_c256_hi, 7);
int16_t max_y = vgetq_lane_s16(base_y_c256_lo, 0);
@@ -1799,13 +1768,13 @@
uint16x8_t diff_lo = vsubl_u8(a1_y0, a0_y0);
uint16x8_t diff_hi = vsubl_u8(a1_y1, a0_y1);
- // a[x] * 32 + 16
- uint16x8_t a32_lo = vmlal_u8(vdupq_n_u16(16), a0_y0, vdup_n_u8(32));
- uint16x8_t a32_hi = vmlal_u8(vdupq_n_u16(16), a0_y1, vdup_n_u8(32));
+ // a[x] * 32
+ uint16x8_t a32_lo = vshll_n_u8(a0_y0, 5);
+ uint16x8_t a32_hi = vshll_n_u8(a0_y1, 5);
uint16x8_t res0 = vmlaq_u16(a32_lo, diff_lo, shifty_lo);
uint16x8_t res1 = vmlaq_u16(a32_hi, diff_hi, shifty_hi);
- return vcombine_u8(vshrn_n_u16(res0, 5), vshrn_n_u16(res1, 5));
+ return vcombine_u8(vrshrn_n_u16(res0, 5), vrshrn_n_u16(res1, 5));
}
static void dr_prediction_z2_4xH_neon(int H, uint8_t *dst, ptrdiff_t stride,
@@ -1813,17 +1782,12 @@
int upsample_above, int upsample_left,
int dx, int dy) {
const int min_base_x = -(1 << upsample_above);
- const int min_base_y = -(1 << upsample_left);
const int frac_bits_x = 6 - upsample_above;
const int frac_bits_y = 6 - upsample_left;
assert(dx > 0);
- // pre-filter above pixels
- // store in temp buffers:
- // above[x] * 32 + 16
- // above[x+1] - above[x]
- // final pixels will be calculated as:
- // (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5
+ // above0 * (32 - shift) + above1 * shift
+ // = (above1 - above0) * shift + above0 * 32
#if AOM_ARCH_AARCH64
// Use ext rather than loading left + 14 directly to avoid over-read.
@@ -1852,11 +1816,10 @@
uint8x8_t a1_x = a1_x_u8;
uint16x8_t diff = vsubl_u8(a1_x, a0_x); // a[x+1] - a[x]
- uint16x8_t a32 =
- vmlal_u8(vdupq_n_u16(16), a0_x, vdup_n_u8(32)); // a[x] * 32 + 16
+ uint16x8_t a32 = vshll_n_u8(a0_x, 5); // a[x] * 32
uint16x8_t res =
vmlaq_u16(a32, diff, vcombine_u16(shift0, vdup_n_u16(0)));
- uint8x8_t resx = vshrn_n_u16(res, 5);
+ uint8x8_t resx = vrshrn_n_u16(res, 5);
vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(resx), 0);
} else if (base_min_diff < 4) {
uint8x8_t a0_x_u8, a1_x_u8;
@@ -1869,17 +1832,16 @@
uint16x4_t a0_y;
uint16x4_t a1_y;
uint16x4_t shift1;
- dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, min_base_y,
- frac_bits_y, &a0_y, &a1_y, &shift1);
+ dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, frac_bits_y,
+ &a0_y, &a1_y, &shift1);
a0_x = vcombine_u16(vget_low_u16(a0_x), a0_y);
a1_x = vcombine_u16(vget_low_u16(a1_x), a1_y);
uint16x8_t shift = vcombine_u16(shift0, shift1);
uint16x8_t diff = vsubq_u16(a1_x, a0_x); // a[x+1] - a[x]
- uint16x8_t a32 =
- vmlaq_n_u16(vdupq_n_u16(16), a0_x, 32); // a[x] * 32 + 16
+ uint16x8_t a32 = vshlq_n_u16(a0_x, 5); // a[x] * 32
uint16x8_t res = vmlaq_u16(a32, diff, shift);
- uint8x8_t resx = vshrn_n_u16(res, 5);
+ uint8x8_t resx = vrshrn_n_u16(res, 5);
uint8x8_t resy = vext_u8(resx, vdup_n_u8(0), 4);
uint8x8_t mask = vld1_u8(BaseMask[base_min_diff]);
@@ -1888,12 +1850,12 @@
} else {
uint16x4_t a0_y, a1_y;
uint16x4_t shift1;
- dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, min_base_y,
- frac_bits_y, &a0_y, &a1_y, &shift1);
- uint16x4_t diff = vsub_u16(a1_y, a0_y); // a[x+1] - a[x]
- uint16x4_t a32 = vmla_n_u16(vdup_n_u16(16), a0_y, 32); // a[x] * 32 + 16
+ dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, frac_bits_y,
+ &a0_y, &a1_y, &shift1);
+ uint16x4_t diff = vsub_u16(a1_y, a0_y); // a[x+1] - a[x]
+ uint16x4_t a32 = vshl_n_u16(a0_y, 5); // a[x] * 32
uint16x4_t res = vmla_u16(a32, diff, shift1);
- uint8x8_t resy = vshrn_n_u16(vcombine_u16(res, vdup_n_u16(0)), 5);
+ uint8x8_t resy = vrshrn_n_u16(vcombine_u16(res, vdup_n_u16(0)), 5);
vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(resy), 0);
}
@@ -1908,16 +1870,11 @@
int upsample_above, int upsample_left,
int dx, int dy) {
const int min_base_x = -(1 << upsample_above);
- const int min_base_y = -(1 << upsample_left);
const int frac_bits_x = 6 - upsample_above;
const int frac_bits_y = 6 - upsample_left;
- // pre-filter above pixels
- // store in temp buffers:
- // above[x] * 32 + 16
- // above[x+1] - above[x]
- // final pixels will be calculated as:
- // (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5
+ // above0 * (32 - shift) + above1 * shift
+ // = (above1 - above0) * shift + above0 * 32
#if AOM_ARCH_AARCH64
// Use ext rather than loading left + 30 directly to avoid over-read.
@@ -1945,14 +1902,14 @@
} else if (base_min_diff < 8) {
uint8x8_t resx =
dr_prediction_z2_8xH_above_neon(above, upsample_above, dx, base_x, y);
- uint8x8_t resy = dr_prediction_z2_8xH_left_neon(
- LEFT, upsample_left, dy, r, min_base_y, frac_bits_y);
+ uint8x8_t resy = dr_prediction_z2_8xH_left_neon(LEFT, upsample_left, dy,
+ r, frac_bits_y);
uint8x8_t mask = vld1_u8(BaseMask[base_min_diff]);
uint8x8_t resxy = vbsl_u8(mask, resy, resx);
vst1_u8(dst, resxy);
} else {
- uint8x8_t resy = dr_prediction_z2_8xH_left_neon(
- LEFT, upsample_left, dy, r, min_base_y, frac_bits_y);
+ uint8x8_t resy = dr_prediction_z2_8xH_left_neon(LEFT, upsample_left, dy,
+ r, frac_bits_y);
vst1_u8(dst, resy);
}