Refactor 8x8 16-bit Neon transpose functions
Refactor the Neon implementations of transpose_s16_8x8(q) and
transpose_u16_8x8 so that the final step compiles to 8 ZIP1/ZIP2
instructions instead of 8 EXT/MOV pairs. This removes 8 instructions
per call to transpose_s16_8x8(q) or transpose_u16_8x8 in cases where
the result stays in registers for further processing rather than
being stored to memory, as in aom_hadamard_8x8_neon, for example.
Co-authored-by: Jonathan Wright <jonathan.wright@arm.com>
Change-Id: I470442d3392acf38c12817b87bdaa46eee887ff6
diff --git a/aom_dsp/arm/transpose_neon.h b/aom_dsp/arm/transpose_neon.h
index 26fc1fd..68ec397 100644
--- a/aom_dsp/arm/transpose_neon.h
+++ b/aom_dsp/arm/transpose_neon.h
@@ -258,13 +258,19 @@
a[3] = vreinterpretq_u16_u32(c1.val[1]);
}
-static INLINE uint16x8x2_t aom_vtrnq_u64_to_u16(const uint32x4_t a0,
- const uint32x4_t a1) {
+static INLINE uint16x8x2_t aom_vtrnq_u64_to_u16(uint32x4_t a0, uint32x4_t a1) {
uint16x8x2_t b0;
+#if defined(__aarch64__)  // transpose the 64-bit lanes of a0 and a1 directly
+ b0.val[0] = vreinterpretq_u16_u64(
+ vtrn1q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1)));
+ b0.val[1] = vreinterpretq_u16_u64(
+ vtrn2q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1)));
+#else  // fallback: combine the low halves, then the high halves, of a0 and a1
b0.val[0] = vcombine_u16(vreinterpret_u16_u32(vget_low_u32(a0)),
vreinterpret_u16_u32(vget_low_u32(a1)));
b0.val[1] = vcombine_u16(vreinterpret_u16_u32(vget_high_u32(a0)),
vreinterpret_u16_u32(vget_high_u32(a1)));
+#endif  // defined(__aarch64__)
return b0;
}
@@ -514,25 +520,45 @@
const uint32x4x2_t c3 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[1]),
vreinterpretq_u32_u16(b3.val[1]));
- *a0 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c0.val[0])),
- vget_low_u16(vreinterpretq_u16_u32(c2.val[0])));
- *a4 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c0.val[0])),
- vget_high_u16(vreinterpretq_u16_u32(c2.val[0])));
+ // Swap 64 bit elements resulting in:
+ // d0.val[0]: 00 10 20 30 40 50 60 70
+ // d0.val[1]: 04 14 24 34 44 54 64 74
+ // d1.val[0]: 01 11 21 31 41 51 61 71
+ // d1.val[1]: 05 15 25 35 45 55 65 75
+ // d2.val[0]: 02 12 22 32 42 52 62 72
+ // d2.val[1]: 06 16 26 36 46 56 66 76
+ // d3.val[0]: 03 13 23 33 43 53 63 73
+ // d3.val[1]: 07 17 27 37 47 57 67 77
- *a2 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c0.val[1])),
- vget_low_u16(vreinterpretq_u16_u32(c2.val[1])));
- *a6 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c0.val[1])),
- vget_high_u16(vreinterpretq_u16_u32(c2.val[1])));
+ const uint16x8x2_t d0 = aom_vtrnq_u64_to_u16(c0.val[0], c2.val[0]);
+ const uint16x8x2_t d1 = aom_vtrnq_u64_to_u16(c1.val[0], c3.val[0]);
+ const uint16x8x2_t d2 = aom_vtrnq_u64_to_u16(c0.val[1], c2.val[1]);
+ const uint16x8x2_t d3 = aom_vtrnq_u64_to_u16(c1.val[1], c3.val[1]);
- *a1 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c1.val[0])),
- vget_low_u16(vreinterpretq_u16_u32(c3.val[0])));
- *a5 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c1.val[0])),
- vget_high_u16(vreinterpretq_u16_u32(c3.val[0])));
+ *a0 = d0.val[0];
+ *a1 = d1.val[0];
+ *a2 = d2.val[0];
+ *a3 = d3.val[0];
+ *a4 = d0.val[1];
+ *a5 = d1.val[1];
+ *a6 = d2.val[1];
+ *a7 = d3.val[1];
+}
- *a3 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c1.val[1])),
- vget_low_u16(vreinterpretq_u16_u32(c3.val[1])));
- *a7 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c1.val[1])),
- vget_high_u16(vreinterpretq_u16_u32(c3.val[1])));
+static INLINE int16x8x2_t aom_vtrnq_s64_to_s16(int32x4_t a0, int32x4_t a1) {
+ int16x8x2_t b0;
+#if defined(__aarch64__)  // transpose the 64-bit lanes of a0 and a1 directly
+ b0.val[0] = vreinterpretq_s16_s64(
+ vtrn1q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1)));
+ b0.val[1] = vreinterpretq_s16_s64(
+ vtrn2q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1)));
+#else  // fallback: combine the low halves, then the high halves, of a0 and a1
+ b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)),
+ vreinterpret_s16_s32(vget_low_s32(a1)));
+ b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)),
+ vreinterpret_s16_s32(vget_high_s32(a1)));
+#endif  // defined(__aarch64__)
+ return b0;
}
static INLINE void transpose_s16_8x8(int16x8_t *a0, int16x8_t *a1,
@@ -582,34 +608,29 @@
const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]),
vreinterpretq_s32_s16(b3.val[1]));
- *a0 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c0.val[0])),
- vget_low_s16(vreinterpretq_s16_s32(c2.val[0])));
- *a4 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c0.val[0])),
- vget_high_s16(vreinterpretq_s16_s32(c2.val[0])));
+ // Swap 64 bit elements resulting in:
+ // d0.val[0]: 00 10 20 30 40 50 60 70
+ // d0.val[1]: 04 14 24 34 44 54 64 74
+ // d1.val[0]: 01 11 21 31 41 51 61 71
+ // d1.val[1]: 05 15 25 35 45 55 65 75
+ // d2.val[0]: 02 12 22 32 42 52 62 72
+ // d2.val[1]: 06 16 26 36 46 56 66 76
+ // d3.val[0]: 03 13 23 33 43 53 63 73
+ // d3.val[1]: 07 17 27 37 47 57 67 77
- *a2 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c0.val[1])),
- vget_low_s16(vreinterpretq_s16_s32(c2.val[1])));
- *a6 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c0.val[1])),
- vget_high_s16(vreinterpretq_s16_s32(c2.val[1])));
+ const int16x8x2_t d0 = aom_vtrnq_s64_to_s16(c0.val[0], c2.val[0]);
+ const int16x8x2_t d1 = aom_vtrnq_s64_to_s16(c1.val[0], c3.val[0]);
+ const int16x8x2_t d2 = aom_vtrnq_s64_to_s16(c0.val[1], c2.val[1]);
+ const int16x8x2_t d3 = aom_vtrnq_s64_to_s16(c1.val[1], c3.val[1]);
- *a1 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c1.val[0])),
- vget_low_s16(vreinterpretq_s16_s32(c3.val[0])));
- *a5 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c1.val[0])),
- vget_high_s16(vreinterpretq_s16_s32(c3.val[0])));
-
- *a3 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c1.val[1])),
- vget_low_s16(vreinterpretq_s16_s32(c3.val[1])));
- *a7 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c1.val[1])),
- vget_high_s16(vreinterpretq_s16_s32(c3.val[1])));
-}
-
-static INLINE int16x8x2_t aom_vtrnq_s64_to_s16(int32x4_t a0, int32x4_t a1) {
- int16x8x2_t b0;
- b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)),
- vreinterpret_s16_s32(vget_low_s32(a1)));
- b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)),
- vreinterpret_s16_s32(vget_high_s32(a1)));
- return b0;
+ *a0 = d0.val[0];
+ *a1 = d1.val[0];
+ *a2 = d2.val[0];
+ *a3 = d3.val[0];
+ *a4 = d0.val[1];
+ *a5 = d1.val[1];
+ *a6 = d2.val[1];
+ *a7 = d3.val[1];
}
static INLINE void transpose_s16_8x8q(int16x8_t *a0, int16x8_t *out) {
@@ -665,6 +686,7 @@
// d2.val[1]: 06 16 26 36 46 56 66 76
// d3.val[0]: 03 13 23 33 43 53 63 73
// d3.val[1]: 07 17 27 37 47 57 67 77
+
const int16x8x2_t d0 = aom_vtrnq_s64_to_s16(c0.val[0], c2.val[0]);
const int16x8x2_t d1 = aom_vtrnq_s64_to_s16(c1.val[0], c3.val[0]);
const int16x8x2_t d2 = aom_vtrnq_s64_to_s16(c0.val[1], c2.val[1]);