Refactor 8x8 16-bit Neon transpose functions

Refactor the Neon implementation of transpose_s16_8x8(q) and
transpose_u16_8x8 so that the final step compiles to 8 ZIP1/ZIP2
instructions rather than 8 EXT, MOV pairs. This saves 8 instructions
per call to transpose_s16_8x8(q) and transpose_u16_8x8 whenever the
result stays in registers for further processing, rather than being
stored to memory, as in aom_hadamard_8x8_neon, for example.
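
The saving is only visible where the transposed rows are consumed
directly from registers. A minimal sketch of that usage pattern
(hypothetical helper, not part of this patch):

  #include <arm_neon.h>
  #include <stdint.h>
  #include "aom_dsp/arm/transpose_neon.h"

  // Transpose an 8x8 int16_t block and consume the rows while they are
  // still in registers (cf. aom_hadamard_8x8_neon); here, a per-lane
  // sum of all eight transposed rows.
  static void transpose_and_sum_s16_8x8(const int16_t *src, int16_t *dst) {
    int16x8_t r0 = vld1q_s16(src + 0 * 8);
    int16x8_t r1 = vld1q_s16(src + 1 * 8);
    int16x8_t r2 = vld1q_s16(src + 2 * 8);
    int16x8_t r3 = vld1q_s16(src + 3 * 8);
    int16x8_t r4 = vld1q_s16(src + 4 * 8);
    int16x8_t r5 = vld1q_s16(src + 5 * 8);
    int16x8_t r6 = vld1q_s16(src + 6 * 8);
    int16x8_t r7 = vld1q_s16(src + 7 * 8);

    // The final step of the transpose now maps to ZIP1/ZIP2 on AArch64.
    transpose_s16_8x8(&r0, &r1, &r2, &r3, &r4, &r5, &r6, &r7);

    int16x8_t sum =
        vaddq_s16(vaddq_s16(vaddq_s16(r0, r1), vaddq_s16(r2, r3)),
                  vaddq_s16(vaddq_s16(r4, r5), vaddq_s16(r6, r7)));
    vst1q_s16(dst, sum);
  }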

Co-authored-by: Jonathan Wright <jonathan.wright@arm.com>
Change-Id: I470442d3392acf38c12817b87bdaa46eee887ff6
diff --git a/aom_dsp/arm/transpose_neon.h b/aom_dsp/arm/transpose_neon.h
index 26fc1fd..68ec397 100644
--- a/aom_dsp/arm/transpose_neon.h
+++ b/aom_dsp/arm/transpose_neon.h
@@ -258,13 +258,19 @@
   a[3] = vreinterpretq_u16_u32(c1.val[1]);
 }
 
-static INLINE uint16x8x2_t aom_vtrnq_u64_to_u16(const uint32x4_t a0,
-                                                const uint32x4_t a1) {
+static INLINE uint16x8x2_t aom_vtrnq_u64_to_u16(uint32x4_t a0, uint32x4_t a1) {
   uint16x8x2_t b0;
+#if defined(__aarch64__)
+  b0.val[0] = vreinterpretq_u16_u64(
+      vtrn1q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1)));
+  b0.val[1] = vreinterpretq_u16_u64(
+      vtrn2q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1)));
+#else
   b0.val[0] = vcombine_u16(vreinterpret_u16_u32(vget_low_u32(a0)),
                            vreinterpret_u16_u32(vget_low_u32(a1)));
   b0.val[1] = vcombine_u16(vreinterpret_u16_u32(vget_high_u32(a0)),
                            vreinterpret_u16_u32(vget_high_u32(a1)));
+#endif
   return b0;
 }
 
@@ -514,25 +520,45 @@
   const uint32x4x2_t c3 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[1]),
                                     vreinterpretq_u32_u16(b3.val[1]));
 
-  *a0 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c0.val[0])),
-                     vget_low_u16(vreinterpretq_u16_u32(c2.val[0])));
-  *a4 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c0.val[0])),
-                     vget_high_u16(vreinterpretq_u16_u32(c2.val[0])));
+  // Swap 64 bit elements resulting in:
+  // d0.val[0]: 00 10 20 30 40 50 60 70
+  // d0.val[1]: 04 14 24 34 44 54 64 74
+  // d1.val[0]: 01 11 21 31 41 51 61 71
+  // d1.val[1]: 05 15 25 35 45 55 65 75
+  // d2.val[0]: 02 12 22 32 42 52 62 72
+  // d2.val[1]: 06 16 26 36 46 56 66 76
+  // d3.val[0]: 03 13 23 33 43 53 63 73
+  // d3.val[1]: 07 17 27 37 47 57 67 77
 
-  *a2 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c0.val[1])),
-                     vget_low_u16(vreinterpretq_u16_u32(c2.val[1])));
-  *a6 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c0.val[1])),
-                     vget_high_u16(vreinterpretq_u16_u32(c2.val[1])));
+  const uint16x8x2_t d0 = aom_vtrnq_u64_to_u16(c0.val[0], c2.val[0]);
+  const uint16x8x2_t d1 = aom_vtrnq_u64_to_u16(c1.val[0], c3.val[0]);
+  const uint16x8x2_t d2 = aom_vtrnq_u64_to_u16(c0.val[1], c2.val[1]);
+  const uint16x8x2_t d3 = aom_vtrnq_u64_to_u16(c1.val[1], c3.val[1]);
 
-  *a1 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c1.val[0])),
-                     vget_low_u16(vreinterpretq_u16_u32(c3.val[0])));
-  *a5 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c1.val[0])),
-                     vget_high_u16(vreinterpretq_u16_u32(c3.val[0])));
+  *a0 = d0.val[0];
+  *a1 = d1.val[0];
+  *a2 = d2.val[0];
+  *a3 = d3.val[0];
+  *a4 = d0.val[1];
+  *a5 = d1.val[1];
+  *a6 = d2.val[1];
+  *a7 = d3.val[1];
+}
 
-  *a3 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c1.val[1])),
-                     vget_low_u16(vreinterpretq_u16_u32(c3.val[1])));
-  *a7 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c1.val[1])),
-                     vget_high_u16(vreinterpretq_u16_u32(c3.val[1])));
+static INLINE int16x8x2_t aom_vtrnq_s64_to_s16(int32x4_t a0, int32x4_t a1) {
+  int16x8x2_t b0;
+#if defined(__aarch64__)
+  b0.val[0] = vreinterpretq_s16_s64(
+      vtrn1q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1)));
+  b0.val[1] = vreinterpretq_s16_s64(
+      vtrn2q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1)));
+#else
+  b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)),
+                           vreinterpret_s16_s32(vget_low_s32(a1)));
+  b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)),
+                           vreinterpret_s16_s32(vget_high_s32(a1)));
+#endif
+  return b0;
 }
 
 static INLINE void transpose_s16_8x8(int16x8_t *a0, int16x8_t *a1,
@@ -582,34 +608,29 @@
   const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]),
                                    vreinterpretq_s32_s16(b3.val[1]));
 
-  *a0 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c0.val[0])),
-                     vget_low_s16(vreinterpretq_s16_s32(c2.val[0])));
-  *a4 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c0.val[0])),
-                     vget_high_s16(vreinterpretq_s16_s32(c2.val[0])));
+  // Swap 64 bit elements resulting in:
+  // d0.val[0]: 00 10 20 30 40 50 60 70
+  // d0.val[1]: 04 14 24 34 44 54 64 74
+  // d1.val[0]: 01 11 21 31 41 51 61 71
+  // d1.val[1]: 05 15 25 35 45 55 65 75
+  // d2.val[0]: 02 12 22 32 42 52 62 72
+  // d2.val[1]: 06 16 26 36 46 56 66 76
+  // d3.val[0]: 03 13 23 33 43 53 63 73
+  // d3.val[1]: 07 17 27 37 47 57 67 77
 
-  *a2 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c0.val[1])),
-                     vget_low_s16(vreinterpretq_s16_s32(c2.val[1])));
-  *a6 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c0.val[1])),
-                     vget_high_s16(vreinterpretq_s16_s32(c2.val[1])));
+  const int16x8x2_t d0 = aom_vtrnq_s64_to_s16(c0.val[0], c2.val[0]);
+  const int16x8x2_t d1 = aom_vtrnq_s64_to_s16(c1.val[0], c3.val[0]);
+  const int16x8x2_t d2 = aom_vtrnq_s64_to_s16(c0.val[1], c2.val[1]);
+  const int16x8x2_t d3 = aom_vtrnq_s64_to_s16(c1.val[1], c3.val[1]);
 
-  *a1 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c1.val[0])),
-                     vget_low_s16(vreinterpretq_s16_s32(c3.val[0])));
-  *a5 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c1.val[0])),
-                     vget_high_s16(vreinterpretq_s16_s32(c3.val[0])));
-
-  *a3 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c1.val[1])),
-                     vget_low_s16(vreinterpretq_s16_s32(c3.val[1])));
-  *a7 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c1.val[1])),
-                     vget_high_s16(vreinterpretq_s16_s32(c3.val[1])));
-}
-
-static INLINE int16x8x2_t aom_vtrnq_s64_to_s16(int32x4_t a0, int32x4_t a1) {
-  int16x8x2_t b0;
-  b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)),
-                           vreinterpret_s16_s32(vget_low_s32(a1)));
-  b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)),
-                           vreinterpret_s16_s32(vget_high_s32(a1)));
-  return b0;
+  *a0 = d0.val[0];
+  *a1 = d1.val[0];
+  *a2 = d2.val[0];
+  *a3 = d3.val[0];
+  *a4 = d0.val[1];
+  *a5 = d1.val[1];
+  *a6 = d2.val[1];
+  *a7 = d3.val[1];
 }
 
 static INLINE void transpose_s16_8x8q(int16x8_t *a0, int16x8_t *out) {
@@ -665,6 +686,7 @@
   // d2.val[1]: 06 16 26 36 46 56 66 76
   // d3.val[0]: 03 13 23 33 43 53 63 73
   // d3.val[1]: 07 17 27 37 47 57 67 77
+
   const int16x8x2_t d0 = aom_vtrnq_s64_to_s16(c0.val[0], c2.val[0]);
   const int16x8x2_t d1 = aom_vtrnq_s64_to_s16(c1.val[0], c3.val[0]);
   const int16x8x2_t d2 = aom_vtrnq_s64_to_s16(c0.val[1], c2.val[1]);
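
A quick way to sanity-check the refactored path is to compare the Neon
transpose against a scalar reference; a minimal sketch (hypothetical
test, not part of this patch):

  #include <assert.h>
  #include <string.h>
  #include <arm_neon.h>
  #include "aom_dsp/arm/transpose_neon.h"

  static void check_transpose_s16_8x8(void) {
    int16_t in[64], ref[64], out[64];
    for (int i = 0; i < 64; ++i) in[i] = (int16_t)(3 * i - 77);

    // Scalar reference: ref[i][j] = in[j][i].
    for (int i = 0; i < 8; ++i)
      for (int j = 0; j < 8; ++j) ref[i * 8 + j] = in[j * 8 + i];

    int16x8_t r[8];
    for (int i = 0; i < 8; ++i) r[i] = vld1q_s16(in + i * 8);
    transpose_s16_8x8(&r[0], &r[1], &r[2], &r[3], &r[4], &r[5], &r[6],
                      &r[7]);
    for (int i = 0; i < 8; ++i) vst1q_s16(out + i * 8, r[i]);

    assert(memcmp(ref, out, sizeof(ref)) == 0);
  }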