Replace 64-bit shifts with vextq in cdef_find_dir_neon

cdef_find_dir_neon uses shifts over 64-bit elements to shuffle the data
for each direction, which is rather inefficient. Replace all shifts with
EXT instructions, which gives around a 50% speedup for this function.

Change-Id: I885e10e6dd199c9f9a33cb9a43ec9d45431212b6
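
For context (not part of the patch): a minimal sketch contrasting the two
idioms for a right shift of a 128-bit int16x8_t vector by 2 bytes (one
lane). The first helper is roughly what the removed v128_shr_2_byte_neon
macro expanded to; the helper names below are illustrative only.

  #include <arm_neon.h>

  // Old idiom: split the vector into two 64-bit halves and emulate the
  // byte shift with 64-bit shifts plus an ORR to carry bits across the
  // halves.
  static inline int16x8_t shr_2_bytes_with_shifts(int16x8_t a) {
    uint64x2_t a_u64 = vreinterpretq_u64_s16(a);
    uint64x1_t lo = vget_low_u64(a_u64);
    uint64x1_t hi = vget_high_u64(a_u64);
    uint64x1_t new_lo = vorr_u64(vshr_n_u64(lo, 16), vshl_n_u64(hi, 48));
    uint64x1_t new_hi = vshr_n_u64(hi, 16);
    return vreinterpretq_s16_u64(vcombine_u64(new_lo, new_hi));
  }

  // New idiom: a single EXT takes the upper seven lanes of 'a' followed
  // by one lane of zeros, giving the same 2-byte right shift.
  static inline int16x8_t shr_2_bytes_with_ext(int16x8_t a) {
    return vextq_s16(a, vdupq_n_s16(0), 1);
  }

The same pattern with the operands swapped (zeros first) gives the left
shifts, so every v128_shl_n_byte_neon/v128_shr_n_byte_neon call in
compute_directions_neon maps to a single vextq_s16.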
diff --git a/av1/common/arm/cdef_block_neon.c b/av1/common/arm/cdef_block_neon.c
index 24d4cf9..a070927 100644
--- a/av1/common/arm/cdef_block_neon.c
+++ b/av1/common/arm/cdef_block_neon.c
@@ -74,92 +74,6 @@
   } while (--height != 0);
 }
 
-static INLINE int16x8_t v128_from_64_neon(int64_t a, int64_t b) {
-  return vreinterpretq_s16_s64(vcombine_s64(vcreate_s64(a), vcreate_s64(b)));
-}
-
-#define SHL_HIGH_NEON(n)                                                       \
-  static INLINE int16x8_t v128_shl_##n##_byte_neon(int16x8_t a) {              \
-    int64x2_t a_s64 = vreinterpretq_s64_s16(a);                                \
-    return v128_from_64_neon(                                                  \
-        0, vget_lane_u64(vshl_n_u64(vreinterpret_u64_s64(vget_low_s64(a_s64)), \
-                                    (n - 8) * 8),                              \
-                         0));                                                  \
-  }
-
-#define SHL_NEON(n)                                                      \
-  static INLINE int16x8_t v128_shl_##n##_byte_neon(int16x8_t a) {        \
-    int64x2_t a_s64 = vreinterpretq_s64_s16(a);                          \
-    return v128_from_64_neon(                                            \
-        0, vget_lane_u64(vreinterpret_u64_s64(vget_low_s64(a_s64)), 0)); \
-  }
-
-#define SHL_LOW_NEON(n)                                                        \
-  static INLINE int16x8_t v128_shl_##n##_byte_neon(int16x8_t a) {              \
-    int64x2_t a_s64 = vreinterpretq_s64_s16(a);                                \
-    return v128_from_64_neon(                                                  \
-        vget_lane_u64(                                                         \
-            vshl_n_u64(vreinterpret_u64_s64(vget_low_s64(a_s64)), n * 8), 0),  \
-        vget_lane_u64(                                                         \
-            vorr_u64(                                                          \
-                vshl_n_u64(vreinterpret_u64_s64(vget_high_s64(a_s64)), n * 8), \
-                vshr_n_u64(vreinterpret_u64_s64(vget_low_s64(a_s64)),          \
-                           (8 - n) * 8)),                                      \
-            0));                                                               \
-  }
-
-SHL_HIGH_NEON(14)
-SHL_HIGH_NEON(12)
-SHL_HIGH_NEON(10)
-SHL_NEON(8)
-SHL_LOW_NEON(6)
-SHL_LOW_NEON(4)
-SHL_LOW_NEON(2)
-
-#define v128_shl_n_byte_neon(a, n) v128_shl_##n##_byte_neon(a)
-
-#define SHR_HIGH_NEON(n)                                                     \
-  static INLINE int16x8_t v128_shr_##n##_byte_neon(int16x8_t a) {            \
-    int64x2_t a_s64 = vreinterpretq_s64_s16(a);                              \
-    return v128_from_64_neon(                                                \
-        vget_lane_u64(vshr_n_u64(vreinterpret_u64_s64(vget_high_s64(a_s64)), \
-                                 (n - 8) * 8),                               \
-                      0),                                                    \
-        0);                                                                  \
-  }
-
-#define SHR_NEON(n)                                                       \
-  static INLINE int16x8_t v128_shr_##n##_byte_neon(int16x8_t a) {         \
-    int64x2_t a_s64 = vreinterpretq_s64_s16(a);                           \
-    return v128_from_64_neon(                                             \
-        vget_lane_u64(vreinterpret_u64_s64(vget_high_s64(a_s64)), 0), 0); \
-  }
-
-#define SHR_LOW_NEON(n)                                                       \
-  static INLINE int16x8_t v128_shr_##n##_byte_neon(int16x8_t a) {             \
-    int64x2_t a_s64 = vreinterpretq_s64_s16(a);                               \
-    return v128_from_64_neon(                                                 \
-        vget_lane_u64(                                                        \
-            vorr_u64(                                                         \
-                vshr_n_u64(vreinterpret_u64_s64(vget_low_s64(a_s64)), n * 8), \
-                vshl_n_u64(vreinterpret_u64_s64(vget_high_s64(a_s64)),        \
-                           (8 - n) * 8)),                                     \
-            0),                                                               \
-        vget_lane_u64(                                                        \
-            vshr_n_u64(vreinterpret_u64_s64(vget_high_s64(a_s64)), n * 8),    \
-            0));                                                              \
-  }
-
-SHR_HIGH_NEON(14)
-SHR_HIGH_NEON(12)
-SHR_HIGH_NEON(10)
-SHR_NEON(8)
-SHR_LOW_NEON(6)
-SHR_LOW_NEON(4)
-SHR_LOW_NEON(2)
-
-#define v128_shr_n_byte_neon(a, n) v128_shr_##n##_byte_neon(a)
-
 static INLINE uint32x4_t v128_madd_s16_neon(int16x8_t a, int16x8_t b) {
   uint32x4_t t1 =
       vreinterpretq_u32_s32(vmull_s16(vget_low_s16(a), vget_low_s16(b)));
@@ -217,57 +131,87 @@
   return partiala_u32;
 }
 
+// This function is called a first time to compute the cost along directions 4,
+// 5, 6, 7, and then a second time on a rotated block to compute directions
+// 0, 1, 2, 3. (0 means 45-degree up-right, 2 is horizontal, and so on.)
+//
+// For each direction the lines are shifted so that we can perform a
+// partial sum on each vector element. For example, direction 5 is "south by
+// southeast", so we need to add the pixels along each line i below:
+//
+// 0  1 2 3 4 5 6 7
+// 0  1 2 3 4 5 6 7
+// 8  0 1 2 3 4 5 6
+// 8  0 1 2 3 4 5 6
+// 9  8 0 1 2 3 4 5
+// 9  8 0 1 2 3 4 5
+// 10 9 8 0 1 2 3 4
+// 10 9 8 0 1 2 3 4
+//
+// For this to fit nicely in vectors, the lines need to be shifted like so:
+//        0 1 2 3 4 5 6 7
+//        0 1 2 3 4 5 6 7
+//      8 0 1 2 3 4 5 6
+//      8 0 1 2 3 4 5 6
+//    9 8 0 1 2 3 4 5
+//    9 8 0 1 2 3 4 5
+// 10 9 8 0 1 2 3 4
+// 10 9 8 0 1 2 3 4
+//
+// In this configuration we can now perform SIMD additions to get the cost
+// along direction 5. Since this won't fit into a single 128-bit vector, we use
+// two of them to compute each half of the new configuration, and pad the empty
+// spaces with zeros. The other directions are shifted similarly, except
+// direction 6, which is straightforward as it is the vertical direction.
 static INLINE uint32x4_t compute_directions_neon(int16x8_t lines[8],
                                                  uint32_t cost[4]) {
-  int16x8_t partial4a, partial4b, partial5a, partial5b, partial6, partial7a,
-      partial7b;
-  int16x8_t tmp;
+  const int16x8_t zero = vdupq_n_s16(0);
 
   // Partial sums for lines 0 and 1.
-  partial4a = v128_shl_n_byte_neon(lines[0], 14);
-  partial4b = v128_shr_n_byte_neon(lines[0], 2);
-  partial4a = vaddq_s16(partial4a, v128_shl_n_byte_neon(lines[1], 12));
-  partial4b = vaddq_s16(partial4b, v128_shr_n_byte_neon(lines[1], 4));
-  tmp = vaddq_s16(lines[0], lines[1]);
-  partial5a = v128_shl_n_byte_neon(tmp, 10);
-  partial5b = v128_shr_n_byte_neon(tmp, 6);
-  partial7a = v128_shl_n_byte_neon(tmp, 4);
-  partial7b = v128_shr_n_byte_neon(tmp, 12);
-  partial6 = tmp;
+  int16x8_t partial4a = vextq_s16(zero, lines[0], 1);
+  partial4a = vaddq_s16(partial4a, vextq_s16(zero, lines[1], 2));
+  int16x8_t partial4b = vextq_s16(lines[0], zero, 1);
+  partial4b = vaddq_s16(partial4b, vextq_s16(lines[1], zero, 2));
+  int16x8_t tmp = vaddq_s16(lines[0], lines[1]);
+  int16x8_t partial5a = vextq_s16(zero, tmp, 3);
+  int16x8_t partial5b = vextq_s16(tmp, zero, 3);
+  int16x8_t partial7a = vextq_s16(zero, tmp, 6);
+  int16x8_t partial7b = vextq_s16(tmp, zero, 6);
+  int16x8_t partial6 = tmp;
 
   // Partial sums for lines 2 and 3.
-  partial4a = vaddq_s16(partial4a, v128_shl_n_byte_neon(lines[2], 10));
-  partial4b = vaddq_s16(partial4b, v128_shr_n_byte_neon(lines[2], 6));
-  partial4a = vaddq_s16(partial4a, v128_shl_n_byte_neon(lines[3], 8));
-  partial4b = vaddq_s16(partial4b, v128_shr_n_byte_neon(lines[3], 8));
+  partial4a = vaddq_s16(partial4a, vextq_s16(zero, lines[2], 3));
+  partial4a = vaddq_s16(partial4a, vextq_s16(zero, lines[3], 4));
+  partial4b = vaddq_s16(partial4b, vextq_s16(lines[2], zero, 3));
+  partial4b = vaddq_s16(partial4b, vextq_s16(lines[3], zero, 4));
   tmp = vaddq_s16(lines[2], lines[3]);
-  partial5a = vaddq_s16(partial5a, v128_shl_n_byte_neon(tmp, 8));
-  partial5b = vaddq_s16(partial5b, v128_shr_n_byte_neon(tmp, 8));
-  partial7a = vaddq_s16(partial7a, v128_shl_n_byte_neon(tmp, 6));
-  partial7b = vaddq_s16(partial7b, v128_shr_n_byte_neon(tmp, 10));
+  partial5a = vaddq_s16(partial5a, vextq_s16(zero, tmp, 4));
+  partial5b = vaddq_s16(partial5b, vextq_s16(tmp, zero, 4));
+  partial7a = vaddq_s16(partial7a, vextq_s16(zero, tmp, 5));
+  partial7b = vaddq_s16(partial7b, vextq_s16(tmp, zero, 5));
   partial6 = vaddq_s16(partial6, tmp);
 
   // Partial sums for lines 4 and 5.
-  partial4a = vaddq_s16(partial4a, v128_shl_n_byte_neon(lines[4], 6));
-  partial4b = vaddq_s16(partial4b, v128_shr_n_byte_neon(lines[4], 10));
-  partial4a = vaddq_s16(partial4a, v128_shl_n_byte_neon(lines[5], 4));
-  partial4b = vaddq_s16(partial4b, v128_shr_n_byte_neon(lines[5], 12));
+  partial4a = vaddq_s16(partial4a, vextq_s16(zero, lines[4], 5));
+  partial4a = vaddq_s16(partial4a, vextq_s16(zero, lines[5], 6));
+  partial4b = vaddq_s16(partial4b, vextq_s16(lines[4], zero, 5));
+  partial4b = vaddq_s16(partial4b, vextq_s16(lines[5], zero, 6));
   tmp = vaddq_s16(lines[4], lines[5]);
-  partial5a = vaddq_s16(partial5a, v128_shl_n_byte_neon(tmp, 6));
-  partial5b = vaddq_s16(partial5b, v128_shr_n_byte_neon(tmp, 10));
-  partial7a = vaddq_s16(partial7a, v128_shl_n_byte_neon(tmp, 8));
-  partial7b = vaddq_s16(partial7b, v128_shr_n_byte_neon(tmp, 8));
+  partial5a = vaddq_s16(partial5a, vextq_s16(zero, tmp, 5));
+  partial5b = vaddq_s16(partial5b, vextq_s16(tmp, zero, 5));
+  partial7a = vaddq_s16(partial7a, vextq_s16(zero, tmp, 4));
+  partial7b = vaddq_s16(partial7b, vextq_s16(tmp, zero, 4));
   partial6 = vaddq_s16(partial6, tmp);
 
   // Partial sums for lines 6 and 7.
-  partial4a = vaddq_s16(partial4a, v128_shl_n_byte_neon(lines[6], 2));
-  partial4b = vaddq_s16(partial4b, v128_shr_n_byte_neon(lines[6], 14));
+  partial4a = vaddq_s16(partial4a, vextq_s16(zero, lines[6], 7));
   partial4a = vaddq_s16(partial4a, lines[7]);
+  partial4b = vaddq_s16(partial4b, vextq_s16(lines[6], zero, 7));
   tmp = vaddq_s16(lines[6], lines[7]);
-  partial5a = vaddq_s16(partial5a, v128_shl_n_byte_neon(tmp, 4));
-  partial5b = vaddq_s16(partial5b, v128_shr_n_byte_neon(tmp, 12));
-  partial7a = vaddq_s16(partial7a, v128_shl_n_byte_neon(tmp, 10));
-  partial7b = vaddq_s16(partial7b, v128_shr_n_byte_neon(tmp, 6));
+  partial5a = vaddq_s16(partial5a, vextq_s16(zero, tmp, 6));
+  partial5b = vaddq_s16(partial5b, vextq_s16(tmp, zero, 6));
+  partial7a = vaddq_s16(partial7a, vextq_s16(zero, tmp, 3));
+  partial7b = vaddq_s16(partial7b, vextq_s16(tmp, zero, 3));
   partial6 = vaddq_s16(partial6, tmp);
 
   uint32x4_t const0 = vreinterpretq_u32_u64(