Optimize av1_dr_prediction_z2 Neon implementation

Optimize av1_dr_prediction_z2_neon by cleaning up and simplifying the calculations. This optimization is a port from SVT-AV1.

Change-Id: I38464487ff0fe4aeee2575ca9851944fc133b4e8
diff --git a/aom_dsp/arm/intrapred_neon.c b/aom_dsp/arm/intrapred_neon.c index 6ea58b9..be63638 100644 --- a/aom_dsp/arm/intrapred_neon.c +++ b/aom_dsp/arm/intrapred_neon.c
@@ -1497,19 +1497,19 @@ static AOM_FORCE_INLINE void dr_prediction_z2_4xH_above_neon( const uint8_t *above, int upsample_above, int dx, int base_x, int y, uint8x8_t *a0_x, uint8x8_t *a1_x, uint16x4_t *shift0) { - uint16x4_t r6 = vcreate_u16(0x00C0008000400000); - uint16x4_t ydx = vdup_n_u16(y * dx); + const int x = -y * dx; + if (upsample_above) { // Cannot use LD2 here since we only want to load eight bytes, but LD2 can // only load either 16 or 32. uint8x8_t v_tmp = vld1_u8(above + base_x); *a0_x = vuzp_u8(v_tmp, vdup_n_u8(0)).val[0]; *a1_x = vuzp_u8(v_tmp, vdup_n_u8(0)).val[1]; - *shift0 = vand_u16(vsub_u16(r6, ydx), vdup_n_u16(0x1f)); + *shift0 = vdup_n_u16(x & 0x1f); // ((x << upsample_above) & 0x3f) >> 1 } else { *a0_x = load_unaligned_u8_4x1(above + base_x); *a1_x = load_unaligned_u8_4x1(above + base_x + 1); - *shift0 = vand_u16(vhsub_u16(r6, ydx), vdup_n_u16(0x1f)); + *shift0 = vdup_n_u16((x & 0x3f) >> 1); } } @@ -1519,19 +1519,15 @@ #else const uint8_t *left, #endif - int upsample_left, int dy, int r, int min_base_y, int frac_bits_y, - uint16x4_t *a0_y, uint16x4_t *a1_y, uint16x4_t *shift1) { + int upsample_left, int dy, int r, int frac_bits_y, uint16x4_t *a0_y, + uint16x4_t *a1_y, uint16x4_t *shift1) { int16x4_t dy64 = vdup_n_s16(dy); int16x4_t v_1234 = vcreate_s16(0x0004000300020001); int16x4_t v_frac_bits_y = vdup_n_s16(-frac_bits_y); - int16x4_t min_base_y64 = vdup_n_s16(min_base_y); int16x4_t v_r6 = vdup_n_s16(r << 6); int16x4_t y_c64 = vmls_s16(v_r6, v_1234, dy64); int16x4_t base_y_c64 = vshl_s16(y_c64, v_frac_bits_y); - // Values in base_y_c64 range from -2 through 14 inclusive. 
- base_y_c64 = vmax_s16(base_y_c64, min_base_y64); - #if AOM_ARCH_AARCH64 uint8x8_t left_idx0 = vreinterpret_u8_s16(vadd_s16(base_y_c64, vdup_n_s16(2))); // [0, 16] @@ -1563,7 +1559,8 @@ #endif // AOM_ARCH_AARCH64 if (upsample_left) { - *shift1 = vand_u16(vreinterpret_u16_s16(y_c64), vdup_n_u16(0x1f)); + *shift1 = vand_u16(vreinterpret_u16_s16(y_c64), + vdup_n_u16(0x1f)); // ((y << upsample_left) & 0x3f) >> 1 } else { *shift1 = vand_u16(vshr_n_u16(vreinterpret_u16_s16(y_c64), 1), vdup_n_u16(0x1f)); @@ -1572,10 +1569,7 @@ static AOM_FORCE_INLINE uint8x8_t dr_prediction_z2_8xH_above_neon( const uint8_t *above, int upsample_above, int dx, int base_x, int y) { - uint16x8_t c1234 = vcombine_u16(vcreate_u16(0x0004000300020001), - vcreate_u16(0x0008000700060005)); - uint16x8_t ydx = vdupq_n_u16(y * dx); - uint16x8_t r6 = vshlq_n_u16(vextq_u16(c1234, vdupq_n_u16(0), 2), 6); + const int x = -y * dx; uint16x8_t shift0; uint8x8_t a0_x0; @@ -1584,18 +1578,18 @@ uint8x8x2_t v_tmp = vld2_u8(above + base_x); a0_x0 = v_tmp.val[0]; a1_x0 = v_tmp.val[1]; - shift0 = vandq_u16(vsubq_u16(r6, ydx), vdupq_n_u16(0x1f)); + shift0 = vdupq_n_u16(x & 0x1f); // ((x << upsample_above) & 0x3f) >> 1 } else { a0_x0 = vld1_u8(above + base_x); a1_x0 = vld1_u8(above + base_x + 1); - shift0 = vandq_u16(vhsubq_u16(r6, ydx), vdupq_n_u16(0x1f)); + shift0 = vdupq_n_u16((x & 0x3f) >> 1); } uint16x8_t diff0 = vsubl_u8(a1_x0, a0_x0); // a[x+1] - a[x] - uint16x8_t a32 = - vmlal_u8(vdupq_n_u16(16), a0_x0, vdup_n_u8(32)); // a[x] * 32 + 16 + uint16x8_t a32 = vshll_n_u8(a0_x0, 5); // a[x] * 32 uint16x8_t res = vmlaq_u16(a32, diff0, shift0); - return vshrn_n_u16(res, 5); + + return vrshrn_n_u16(res, 5); } static AOM_FORCE_INLINE uint8x8_t dr_prediction_z2_8xH_left_neon( @@ -1604,20 +1598,16 @@ #else const uint8_t *left, #endif - int upsample_left, int dy, int r, int min_base_y, int frac_bits_y) { + int upsample_left, int dy, int r, int frac_bits_y) { int16x8_t v_r6 = vdupq_n_s16(r << 6); int16x8_t dy128 = 
vdupq_n_s16(dy); int16x8_t v_frac_bits_y = vdupq_n_s16(-frac_bits_y); - int16x8_t min_base_y128 = vdupq_n_s16(min_base_y); uint16x8_t c1234 = vcombine_u16(vcreate_u16(0x0004000300020001), vcreate_u16(0x0008000700060005)); int16x8_t y_c128 = vmlsq_s16(v_r6, vreinterpretq_s16_u16(c1234), dy128); int16x8_t base_y_c128 = vshlq_s16(y_c128, v_frac_bits_y); - // Values in base_y_c128 range from -2 through 31 inclusive. - base_y_c128 = vmaxq_s16(base_y_c128, min_base_y128); - #if AOM_ARCH_AARCH64 uint8x16_t left_idx0 = vreinterpretq_u8_s16(vaddq_s16(base_y_c128, vdupq_n_s16(2))); // [0, 33] @@ -1635,46 +1625,39 @@ uint16x8_t shift1; if (upsample_left) { - shift1 = vandq_u16(vreinterpretq_u16_s16(y_c128), vdupq_n_u16(0x1f)); + shift1 = + vandq_u16(vreinterpretq_u16_s16(y_c128), + vdupq_n_u16(0x1f)); // ((y << upsample_left) & 0x3f) >> 1 } else { shift1 = vshrq_n_u16( vandq_u16(vreinterpretq_u16_s16(y_c128), vdupq_n_u16(0x3f)), 1); } uint16x8_t diff1 = vsubl_u8(a1_x1, a0_x1); - uint16x8_t a32 = vmlal_u8(vdupq_n_u16(16), a0_x1, vdup_n_u8(32)); + uint16x8_t a32 = vshll_n_u8(a0_x1, 5); uint16x8_t res = vmlaq_u16(a32, diff1, shift1); - return vshrn_n_u16(res, 5); + + return vrshrn_n_u16(res, 5); } static AOM_FORCE_INLINE uint8x16_t dr_prediction_z2_WxH_above_neon( const uint8_t *above, int dx, int base_x, int y, int j) { - uint16x8x2_t c0123 = { { vcombine_u16(vcreate_u16(0x0003000200010000), - vcreate_u16(0x0007000600050004)), - vcombine_u16(vcreate_u16(0x000B000A00090008), - vcreate_u16(0x000F000E000D000C)) } }; - uint16x8_t j256 = vdupq_n_u16(j); - uint16x8_t ydx = vdupq_n_u16((uint16_t)(y * dx)); - + const int x = -y * dx; const uint8x16_t a0_x128 = vld1q_u8(above + base_x + j); const uint8x16_t a1_x128 = vld1q_u8(above + base_x + j + 1); - uint16x8_t res6_0 = vshlq_n_u16(vaddq_u16(c0123.val[0], j256), 6); - uint16x8_t res6_1 = vshlq_n_u16(vaddq_u16(c0123.val[1], j256), 6); - uint16x8_t shift0 = - vshrq_n_u16(vandq_u16(vsubq_u16(res6_0, ydx), vdupq_n_u16(0x3f)), 1); - 
uint16x8_t shift1 = - vshrq_n_u16(vandq_u16(vsubq_u16(res6_1, ydx), vdupq_n_u16(0x3f)), 1); + const uint16x8_t shift = vdupq_n_u16((x & 0x3f) >> 1); + // a[x+1] - a[x] uint16x8_t diff0 = vsubl_u8(vget_low_u8(a1_x128), vget_low_u8(a0_x128)); uint16x8_t diff1 = vsubl_u8(vget_high_u8(a1_x128), vget_high_u8(a0_x128)); - // a[x] * 32 + 16 - uint16x8_t a32_0 = - vmlal_u8(vdupq_n_u16(16), vget_low_u8(a0_x128), vdup_n_u8(32)); - uint16x8_t a32_1 = - vmlal_u8(vdupq_n_u16(16), vget_high_u8(a0_x128), vdup_n_u8(32)); - uint16x8_t res0 = vmlaq_u16(a32_0, diff0, shift0); - uint16x8_t res1 = vmlaq_u16(a32_1, diff1, shift1); - return vcombine_u8(vshrn_n_u16(res0, 5), vshrn_n_u16(res1, 5)); + + // a[x] * 32 + uint16x8_t a32_0 = vshll_n_u8(vget_low_u8(a0_x128), 5); + uint16x8_t a32_1 = vshll_n_u8(vget_high_u8(a0_x128), 5); + uint16x8_t res0 = vmlaq_u16(a32_0, diff0, shift); + uint16x8_t res1 = vmlaq_u16(a32_1, diff1, shift); + + return vcombine_u8(vrshrn_n_u16(res0, 5), vrshrn_n_u16(res1, 5)); } static AOM_FORCE_INLINE uint8x16_t dr_prediction_z2_WxH_left_neon( @@ -1686,39 +1669,25 @@ int dy, int r, int j) { // here upsample_above and upsample_left are 0 by design of // av1_use_intra_edge_upsample - const int min_base_y = -1; - - int16x8_t min_base_y256 = vdupq_n_s16(min_base_y); - int16x8_t half_min_base_y256 = vdupq_n_s16(min_base_y >> 1); int16x8_t dy256 = vdupq_n_s16(dy); uint16x8_t j256 = vdupq_n_u16(j); - uint16x8x2_t c0123 = { { vcombine_u16(vcreate_u16(0x0003000200010000), - vcreate_u16(0x0007000600050004)), - vcombine_u16(vcreate_u16(0x000B000A00090008), - vcreate_u16(0x000F000E000D000C)) } }; - uint16x8x2_t c1234 = { { vaddq_u16(c0123.val[0], vdupq_n_u16(1)), - vaddq_u16(c0123.val[1], vdupq_n_u16(1)) } }; + uint16x8x2_t c1234 = { { vcombine_u16(vcreate_u16(0x0004000300020001), + vcreate_u16(0x0008000700060005)), + vcombine_u16(vcreate_u16(0x000C000B000A0009), + vcreate_u16(0x0010000F000E000D)) } }; int16x8_t v_r6 = vdupq_n_s16(r << 6); int16x8_t c256_0 = 
vreinterpretq_s16_u16(vaddq_u16(j256, c1234.val[0])); int16x8_t c256_1 = vreinterpretq_s16_u16(vaddq_u16(j256, c1234.val[1])); - int16x8_t mul16_lo = vreinterpretq_s16_u16( - vminq_u16(vreinterpretq_u16_s16(vmulq_s16(c256_0, dy256)), - vreinterpretq_u16_s16(half_min_base_y256))); - int16x8_t mul16_hi = vreinterpretq_s16_u16( - vminq_u16(vreinterpretq_u16_s16(vmulq_s16(c256_1, dy256)), - vreinterpretq_u16_s16(half_min_base_y256))); - int16x8_t y_c256_lo = vsubq_s16(v_r6, mul16_lo); - int16x8_t y_c256_hi = vsubq_s16(v_r6, mul16_hi); + + int16x8_t y_c256_lo = vmlsq_s16(v_r6, c256_0, dy256); + int16x8_t y_c256_hi = vmlsq_s16(v_r6, c256_1, dy256); int16x8_t base_y_c256_lo = vshrq_n_s16(y_c256_lo, 6); int16x8_t base_y_c256_hi = vshrq_n_s16(y_c256_hi, 6); - base_y_c256_lo = vmaxq_s16(min_base_y256, base_y_c256_lo); - base_y_c256_hi = vmaxq_s16(min_base_y256, base_y_c256_hi); - #if !AOM_ARCH_AARCH64 int16_t min_y = vgetq_lane_s16(base_y_c256_hi, 7); int16_t max_y = vgetq_lane_s16(base_y_c256_lo, 0); @@ -1799,13 +1768,13 @@ uint16x8_t diff_lo = vsubl_u8(a1_y0, a0_y0); uint16x8_t diff_hi = vsubl_u8(a1_y1, a0_y1); // a[x] * 32 + 16 - uint16x8_t a32_lo = vmlal_u8(vdupq_n_u16(16), a0_y0, vdup_n_u8(32)); - uint16x8_t a32_hi = vmlal_u8(vdupq_n_u16(16), a0_y1, vdup_n_u8(32)); + uint16x8_t a32_lo = vshll_n_u8(a0_y0, 5); + uint16x8_t a32_hi = vshll_n_u8(a0_y1, 5); uint16x8_t res0 = vmlaq_u16(a32_lo, diff_lo, shifty_lo); uint16x8_t res1 = vmlaq_u16(a32_hi, diff_hi, shifty_hi); - return vcombine_u8(vshrn_n_u16(res0, 5), vshrn_n_u16(res1, 5)); + return vcombine_u8(vrshrn_n_u16(res0, 5), vrshrn_n_u16(res1, 5)); } static void dr_prediction_z2_4xH_neon(int H, uint8_t *dst, ptrdiff_t stride, @@ -1813,17 +1782,12 @@ int upsample_above, int upsample_left, int dx, int dy) { const int min_base_x = -(1 << upsample_above); - const int min_base_y = -(1 << upsample_left); const int frac_bits_x = 6 - upsample_above; const int frac_bits_y = 6 - upsample_left; assert(dx > 0); - // pre-filter above 
pixels - // store in temp buffers: - // above[x] * 32 + 16 - // above[x+1] - above[x] - // final pixels will be calculated as: - // (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5 + // above0 * (32 - shift) + above1 * shift + // = (above1 - above0) * shift + above0 * 32 #if AOM_ARCH_AARCH64 // Use ext rather than loading left + 14 directly to avoid over-read. @@ -1852,11 +1816,10 @@ uint8x8_t a1_x = a1_x_u8; uint16x8_t diff = vsubl_u8(a1_x, a0_x); // a[x+1] - a[x] - uint16x8_t a32 = - vmlal_u8(vdupq_n_u16(16), a0_x, vdup_n_u8(32)); // a[x] * 32 + 16 + uint16x8_t a32 = vshll_n_u8(a0_x, 5); uint16x8_t res = vmlaq_u16(a32, diff, vcombine_u16(shift0, vdup_n_u16(0))); - uint8x8_t resx = vshrn_n_u16(res, 5); + uint8x8_t resx = vrshrn_n_u16(res, 5); vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(resx), 0); } else if (base_min_diff < 4) { uint8x8_t a0_x_u8, a1_x_u8; @@ -1869,17 +1832,16 @@ uint16x4_t a0_y; uint16x4_t a1_y; uint16x4_t shift1; - dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, min_base_y, - frac_bits_y, &a0_y, &a1_y, &shift1); + dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, frac_bits_y, + &a0_y, &a1_y, &shift1); a0_x = vcombine_u16(vget_low_u16(a0_x), a0_y); a1_x = vcombine_u16(vget_low_u16(a1_x), a1_y); uint16x8_t shift = vcombine_u16(shift0, shift1); uint16x8_t diff = vsubq_u16(a1_x, a0_x); // a[x+1] - a[x] - uint16x8_t a32 = - vmlaq_n_u16(vdupq_n_u16(16), a0_x, 32); // a[x] * 32 + 16 + uint16x8_t a32 = vshlq_n_u16(a0_x, 5); uint16x8_t res = vmlaq_u16(a32, diff, shift); - uint8x8_t resx = vshrn_n_u16(res, 5); + uint8x8_t resx = vrshrn_n_u16(res, 5); uint8x8_t resy = vext_u8(resx, vdup_n_u8(0), 4); uint8x8_t mask = vld1_u8(BaseMask[base_min_diff]); @@ -1888,12 +1850,12 @@ } else { uint16x4_t a0_y, a1_y; uint16x4_t shift1; - dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, min_base_y, - frac_bits_y, &a0_y, &a1_y, &shift1); - uint16x4_t diff = vsub_u16(a1_y, a0_y); // a[x+1] - a[x] - uint16x4_t a32 = 
vmla_n_u16(vdup_n_u16(16), a0_y, 32); // a[x] * 32 + 16 + dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, frac_bits_y, + &a0_y, &a1_y, &shift1); + uint16x4_t diff = vsub_u16(a1_y, a0_y); // a[x+1] - a[x] + uint16x4_t a32 = vshl_n_u16(a0_y, 5); uint16x4_t res = vmla_u16(a32, diff, shift1); - uint8x8_t resy = vshrn_n_u16(vcombine_u16(res, vdup_n_u16(0)), 5); + uint8x8_t resy = vrshrn_n_u16(vcombine_u16(res, vdup_n_u16(0)), 5); vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(resy), 0); } @@ -1908,16 +1870,11 @@ int upsample_above, int upsample_left, int dx, int dy) { const int min_base_x = -(1 << upsample_above); - const int min_base_y = -(1 << upsample_left); const int frac_bits_x = 6 - upsample_above; const int frac_bits_y = 6 - upsample_left; - // pre-filter above pixels - // store in temp buffers: - // above[x] * 32 + 16 - // above[x+1] - above[x] - // final pixels will be calculated as: - // (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5 + // above0 * (32 - shift) + above1 * shift + // = (above1 - above0) * shift + above0 * 32 #if AOM_ARCH_AARCH64 // Use ext rather than loading left + 30 directly to avoid over-read. @@ -1945,14 +1902,14 @@ } else if (base_min_diff < 8) { uint8x8_t resx = dr_prediction_z2_8xH_above_neon(above, upsample_above, dx, base_x, y); - uint8x8_t resy = dr_prediction_z2_8xH_left_neon( - LEFT, upsample_left, dy, r, min_base_y, frac_bits_y); + uint8x8_t resy = dr_prediction_z2_8xH_left_neon(LEFT, upsample_left, dy, + r, frac_bits_y); uint8x8_t mask = vld1_u8(BaseMask[base_min_diff]); uint8x8_t resxy = vbsl_u8(mask, resy, resx); vst1_u8(dst, resxy); } else { - uint8x8_t resy = dr_prediction_z2_8xH_left_neon( - LEFT, upsample_left, dy, r, min_base_y, frac_bits_y); + uint8x8_t resy = dr_prediction_z2_8xH_left_neon(LEFT, upsample_left, dy, + r, frac_bits_y); vst1_u8(dst, resy); }