Refactor 8x8 16-bit Neon transpose functions Refactor the Neon implementation of transpose_s16_8x8(q) and transpose_u16_8x8 so that the final step compiles to 8 ZIP1/ZIP2 instructions as opposed to 8 EXT, MOV pairs. This change removes 8 instructions per call to transpose_s16_8x8(q) and transpose_u16_8x8 where the result stays in registers for further processing - rather than being stored to memory - like in aom_hadamard_8x8_neon, for example. Co-authored-by: Jonathan Wright <jonathan.wright@arm.com> Change-Id: I470442d3392acf38c12817b87bdaa46eee887ff6
diff --git a/aom_dsp/arm/transpose_neon.h b/aom_dsp/arm/transpose_neon.h index 26fc1fd..68ec397 100644 --- a/aom_dsp/arm/transpose_neon.h +++ b/aom_dsp/arm/transpose_neon.h
@@ -258,13 +258,19 @@ a[3] = vreinterpretq_u16_u32(c1.val[1]); } -static INLINE uint16x8x2_t aom_vtrnq_u64_to_u16(const uint32x4_t a0, - const uint32x4_t a1) { +static INLINE uint16x8x2_t aom_vtrnq_u64_to_u16(uint32x4_t a0, uint32x4_t a1) { uint16x8x2_t b0; +#if defined(__aarch64__) + b0.val[0] = vreinterpretq_u16_u64( + vtrn1q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1))); + b0.val[1] = vreinterpretq_u16_u64( + vtrn2q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1))); +#else b0.val[0] = vcombine_u16(vreinterpret_u16_u32(vget_low_u32(a0)), vreinterpret_u16_u32(vget_low_u32(a1))); b0.val[1] = vcombine_u16(vreinterpret_u16_u32(vget_high_u32(a0)), vreinterpret_u16_u32(vget_high_u32(a1))); +#endif return b0; } @@ -514,25 +520,45 @@ const uint32x4x2_t c3 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[1]), vreinterpretq_u32_u16(b3.val[1])); - *a0 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c0.val[0])), - vget_low_u16(vreinterpretq_u16_u32(c2.val[0]))); - *a4 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c0.val[0])), - vget_high_u16(vreinterpretq_u16_u32(c2.val[0]))); + // Swap 64 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 + // d0.val[1]: 04 14 24 34 44 54 64 74 + // d1.val[0]: 01 11 21 31 41 51 61 71 + // d1.val[1]: 05 15 25 35 45 55 65 75 + // d2.val[0]: 02 12 22 32 42 52 62 72 + // d2.val[1]: 06 16 26 36 46 56 66 76 + // d3.val[0]: 03 13 23 33 43 53 63 73 + // d3.val[1]: 07 17 27 37 47 57 67 77 - *a2 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c0.val[1])), - vget_low_u16(vreinterpretq_u16_u32(c2.val[1]))); - *a6 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c0.val[1])), - vget_high_u16(vreinterpretq_u16_u32(c2.val[1]))); + const uint16x8x2_t d0 = aom_vtrnq_u64_to_u16(c0.val[0], c2.val[0]); + const uint16x8x2_t d1 = aom_vtrnq_u64_to_u16(c1.val[0], c3.val[0]); + const uint16x8x2_t d2 = aom_vtrnq_u64_to_u16(c0.val[1], c2.val[1]); + const uint16x8x2_t d3 = aom_vtrnq_u64_to_u16(c1.val[1], c3.val[1]); - *a1 = 
vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c1.val[0])), - vget_low_u16(vreinterpretq_u16_u32(c3.val[0]))); - *a5 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c1.val[0])), - vget_high_u16(vreinterpretq_u16_u32(c3.val[0]))); + *a0 = d0.val[0]; + *a1 = d1.val[0]; + *a2 = d2.val[0]; + *a3 = d3.val[0]; + *a4 = d0.val[1]; + *a5 = d1.val[1]; + *a6 = d2.val[1]; + *a7 = d3.val[1]; +} - *a3 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c1.val[1])), - vget_low_u16(vreinterpretq_u16_u32(c3.val[1]))); - *a7 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c1.val[1])), - vget_high_u16(vreinterpretq_u16_u32(c3.val[1]))); +static INLINE int16x8x2_t aom_vtrnq_s64_to_s16(int32x4_t a0, int32x4_t a1) { + int16x8x2_t b0; +#if defined(__aarch64__) + b0.val[0] = vreinterpretq_s16_s64( + vtrn1q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1))); + b0.val[1] = vreinterpretq_s16_s64( + vtrn2q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1))); +#else + b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)), + vreinterpret_s16_s32(vget_low_s32(a1))); + b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)), + vreinterpret_s16_s32(vget_high_s32(a1))); +#endif + return b0; } static INLINE void transpose_s16_8x8(int16x8_t *a0, int16x8_t *a1, @@ -582,34 +608,29 @@ const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]), vreinterpretq_s32_s16(b3.val[1])); - *a0 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c0.val[0])), - vget_low_s16(vreinterpretq_s16_s32(c2.val[0]))); - *a4 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c0.val[0])), - vget_high_s16(vreinterpretq_s16_s32(c2.val[0]))); + // Swap 64 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 + // d0.val[1]: 04 14 24 34 44 54 64 74 + // d1.val[0]: 01 11 21 31 41 51 61 71 + // d1.val[1]: 05 15 25 35 45 55 65 75 + // d2.val[0]: 02 12 22 32 42 52 62 72 + // d2.val[1]: 06 16 26 36 46 56 66 76 + // d3.val[0]: 03 13 23 33 43 53 63 73 + // d3.val[1]: 07 17 27 37 
47 57 67 77 - *a2 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c0.val[1])), - vget_low_s16(vreinterpretq_s16_s32(c2.val[1]))); - *a6 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c0.val[1])), - vget_high_s16(vreinterpretq_s16_s32(c2.val[1]))); + const int16x8x2_t d0 = aom_vtrnq_s64_to_s16(c0.val[0], c2.val[0]); + const int16x8x2_t d1 = aom_vtrnq_s64_to_s16(c1.val[0], c3.val[0]); + const int16x8x2_t d2 = aom_vtrnq_s64_to_s16(c0.val[1], c2.val[1]); + const int16x8x2_t d3 = aom_vtrnq_s64_to_s16(c1.val[1], c3.val[1]); - *a1 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c1.val[0])), - vget_low_s16(vreinterpretq_s16_s32(c3.val[0]))); - *a5 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c1.val[0])), - vget_high_s16(vreinterpretq_s16_s32(c3.val[0]))); - - *a3 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c1.val[1])), - vget_low_s16(vreinterpretq_s16_s32(c3.val[1]))); - *a7 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c1.val[1])), - vget_high_s16(vreinterpretq_s16_s32(c3.val[1]))); -} - -static INLINE int16x8x2_t aom_vtrnq_s64_to_s16(int32x4_t a0, int32x4_t a1) { - int16x8x2_t b0; - b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)), - vreinterpret_s16_s32(vget_low_s32(a1))); - b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)), - vreinterpret_s16_s32(vget_high_s32(a1))); - return b0; + *a0 = d0.val[0]; + *a1 = d1.val[0]; + *a2 = d2.val[0]; + *a3 = d3.val[0]; + *a4 = d0.val[1]; + *a5 = d1.val[1]; + *a6 = d2.val[1]; + *a7 = d3.val[1]; } static INLINE void transpose_s16_8x8q(int16x8_t *a0, int16x8_t *out) { @@ -665,6 +686,7 @@ // d2.val[1]: 06 16 26 36 46 56 66 76 // d3.val[0]: 03 13 23 33 43 53 63 73 // d3.val[1]: 07 17 27 37 47 57 67 77 + const int16x8x2_t d0 = aom_vtrnq_s64_to_s16(c0.val[0], c2.val[0]); const int16x8x2_t d1 = aom_vtrnq_s64_to_s16(c1.val[0], c3.val[0]); const int16x8x2_t d2 = aom_vtrnq_s64_to_s16(c0.val[1], c2.val[1]);