Add specific path for horizontal dirs in cdef_find_dir_neon

compute_directions was used for both vertical and horizontal directions,
with a rotation of the input block between the two calls. Remove the
rotation and add a specific function to compute horizontal directions.
This gives around a 9% uplift in cdef_find_dir for both Clang and GCC.
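
For reference, the new horizontal path computes, for direction 2, the
sum over the eight lines of the squared line sum, weighted by
840 / 8 = 105. A minimal scalar sketch of that cost (dir2_cost_scalar
is a hypothetical name, not part of this patch):

  #include <stdint.h>

  static uint32_t dir2_cost_scalar(const int16_t lines[8][8]) {
    uint32_t cost = 0;
    for (int i = 0; i < 8; ++i) {
      // Direction 2 is horizontal, so each line is simply a row.
      int32_t partial = 0;
      for (int j = 0; j < 8; ++j) partial += lines[i][j];
      cost += 105 * (uint32_t)(partial * partial);
    }
    return cost;
  }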

Change-Id: Id68f4a6a43ecf0cf830e73e2366a16bc2a381bda
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index 4261592..30a108e 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -196,6 +196,25 @@
 #endif
 }
 
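+// Horizontally add the 16-bit elements of each of the four input vectors and
+// return the four widened sums, one per 32-bit lane of the result.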
+static INLINE int32x4_t horizontal_add_4d_s16x8(const int16x8_t sum[4]) {
+#if AOM_ARCH_AARCH64
+  const int16x8_t a0 = vpaddq_s16(sum[0], sum[1]);
+  const int16x8_t a1 = vpaddq_s16(sum[2], sum[3]);
+  const int16x8_t b0 = vpaddq_s16(a0, a1);
+  return vpaddlq_s16(b0);
+#else
+  const int16x4_t a0 = vadd_s16(vget_low_s16(sum[0]), vget_high_s16(sum[0]));
+  const int16x4_t a1 = vadd_s16(vget_low_s16(sum[1]), vget_high_s16(sum[1]));
+  const int16x4_t a2 = vadd_s16(vget_low_s16(sum[2]), vget_high_s16(sum[2]));
+  const int16x4_t a3 = vadd_s16(vget_low_s16(sum[3]), vget_high_s16(sum[3]));
+  const int16x4_t b0 = vpadd_s16(a0, a1);
+  const int16x4_t b1 = vpadd_s16(a2, a3);
+  return vpaddlq_s16(vcombine_s16(b0, b1));
+#endif
+}
+
 static INLINE uint32_t horizontal_add_u32x2(const uint32x2_t a) {
 #if AOM_ARCH_AARCH64
   return vaddv_u32(a);
diff --git a/av1/common/arm/cdef_block_neon.c b/av1/common/arm/cdef_block_neon.c
index 45d58aa..68a292b 100644
--- a/av1/common/arm/cdef_block_neon.c
+++ b/av1/common/arm/cdef_block_neon.c
@@ -115,9 +115,8 @@
   return cost;
 }
 
-// This function is called a first time to compute the cost along directions 4,
-// 5, 6, 7, and then a second time on a rotated block to compute directions
-// 0, 1, 2, 3. (0 means 45-degree up-right, 2 is horizontal, and so on.)
+// This function computes the cost along directions 4, 5, 6, 7. (4 is diagonal
+// down-right, 6 is vertical).
 //
 // For each direction the lines are shifted so that we can perform a
 // basic sum on each vector element. For example, direction 5 is "south by
@@ -147,8 +146,8 @@
 // two of them to compute each half of the new configuration, and pad the empty
 // spaces with zeros. Similar shifting is done for other directions, except
 // direction 6 which is straightforward as it's the vertical direction.
-static INLINE uint32x4_t compute_directions_neon(int16x8_t lines[8],
-                                                 uint32_t cost[4]) {
+static INLINE uint32x4_t compute_vert_directions_neon(int16x8_t lines[8],
+                                                      uint32_t cost[4]) {
   const int16x8_t zero = vdupq_n_s16(0);
 
   // Partial sums for lines 0 and 1.
@@ -227,46 +226,168 @@
   return costs[0];
 }
 
-static INLINE int64x2_t ziplo_s64(int32x4_t a, int32x4_t b) {
-  return vcombine_s64(vget_low_s64(vreinterpretq_s64_s32(a)),
-                      vget_low_s64(vreinterpretq_s64_s32(b)));
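+// Pairwise-add the 16-bit lanes of each partial vector, square the widened
+// pair sums, and combine them per 32-bit lane as:
+//   cost = 105 * partialb^2 + const0 * (partiala^2 + partialc^2)
+// partialc is reversed first so lines of equal length share a weight lane.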
+static INLINE uint32x4_t fold_mul_and_sum_pairwise_neon(int16x8_t partiala,
+                                                        int16x8_t partialb,
+                                                        int16x8_t partialc,
+                                                        uint32x4_t const0) {
+  // Reverse the first six 16-bit lanes of partial c (keep lanes 6 and 7):
+  // pattern = { 10 11 8 9 6 7 4 5 2 3 0 1 12 13 14 15 }.
+  uint8x16_t pattern = vreinterpretq_u8_u64(
+      vcombine_u64(vcreate_u64((uint64_t)0x05040706 << 32 | 0x09080b0a),
+                   vcreate_u64((uint64_t)0x0f0e0d0c << 32 | 0x01000302)));
+
+#if AOM_ARCH_AARCH64
+  partialc =
+      vreinterpretq_s16_s8(vqtbl1q_s8(vreinterpretq_s8_s16(partialc), pattern));
+#else
+  int8x8x2_t p = { { vget_low_s8(vreinterpretq_s8_s16(partialc)),
+                     vget_high_s8(vreinterpretq_s8_s16(partialc)) } };
+  int8x8_t shuffle_hi = vtbl2_s8(p, vget_high_s8(vreinterpretq_s8_u8(pattern)));
+  int8x8_t shuffle_lo = vtbl2_s8(p, vget_low_s8(vreinterpretq_s8_u8(pattern)));
+  partialc = vreinterpretq_s16_s8(vcombine_s8(shuffle_lo, shuffle_hi));
+#endif
+
+  int32x4_t partiala_s32 = vpaddlq_s16(partiala);
+  int32x4_t partialb_s32 = vpaddlq_s16(partialb);
+  int32x4_t partialc_s32 = vpaddlq_s16(partialc);
+
+  partiala_s32 = vmulq_s32(partiala_s32, partiala_s32);
+  partialb_s32 = vmulq_s32(partialb_s32, partialb_s32);
+  partialc_s32 = vmulq_s32(partialc_s32, partialc_s32);
+
+  partiala_s32 = vaddq_s32(partiala_s32, partialc_s32);
+
+  uint32x4_t cost = vmulq_n_u32(vreinterpretq_u32_s32(partialb_s32), 105);
+  cost = vmlaq_u32(cost, vreinterpretq_u32_s32(partiala_s32), const0);
+  return cost;
 }
 
-static INLINE int64x2_t ziphi_s64(int32x4_t a, int32x4_t b) {
-  return vcombine_s64(vget_high_s64(vreinterpretq_s64_s32(a)),
-                      vget_high_s64(vreinterpretq_s64_s32(b)));
-}
+// This function computes the cost along directions 0, 1, 2, 3. (0 means
+// 45-degree up-right, 2 is horizontal).
+//
+// For directions 1 and 3 ("east northeast" and "east southeast") the shifted
+// lines need three vectors instead of two. For direction 1, for example, we
+// need to compute the sums along each line i below:
+// 0 0 1 1 2 2 3  3
+// 1 1 2 2 3 3 4  4
+// 2 2 3 3 4 4 5  5
+// 3 3 4 4 5 5 6  6
+// 4 4 5 5 6 6 7  7
+// 5 5 6 6 7 7 8  8
+// 6 6 7 7 8 8 9  9
+// 7 7 8 8 9 9 10 10
+//
+// Which means we need the following configuration:
+// 0 0 1 1 2 2 3 3
+//     1 1 2 2 3 3 4 4
+//         2 2 3 3 4 4 5 5
+//             3 3 4 4 5 5 6 6
+//                 4 4 5 5 6 6 7 7
+//                     5 5 6 6 7 7 8 8
+//                         6 6 7 7 8 8 9 9
+//                             7 7 8 8 9 9 10 10
+//
+// Three vectors are needed to compute this, as well as some extra pairwise
+// additions.
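+//
+// Direction 0 needs only two vectors as before, and direction 2 (horizontal)
+// is a plain sum along each line, handled separately at the end.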
+static uint32x4_t compute_horiz_directions_neon(int16x8_t lines[8],
+                                                uint32_t cost[4]) {
+  const int16x8_t zero = vdupq_n_s16(0);
 
-// Transpose and reverse the order of the lines -- equivalent to a 90-degree
-// counter-clockwise rotation of the pixels.
-static INLINE void array_reverse_transpose_8x8_neon(int16x8_t *in,
-                                                    int16x8_t *res) {
-  const int32x4_t tr0_0 = vreinterpretq_s32_s16(vzipq_s16(in[0], in[1]).val[0]);
-  const int32x4_t tr0_1 = vreinterpretq_s32_s16(vzipq_s16(in[2], in[3]).val[0]);
-  const int32x4_t tr0_2 = vreinterpretq_s32_s16(vzipq_s16(in[0], in[1]).val[1]);
-  const int32x4_t tr0_3 = vreinterpretq_s32_s16(vzipq_s16(in[2], in[3]).val[1]);
-  const int32x4_t tr0_4 = vreinterpretq_s32_s16(vzipq_s16(in[4], in[5]).val[0]);
-  const int32x4_t tr0_5 = vreinterpretq_s32_s16(vzipq_s16(in[6], in[7]).val[0]);
-  const int32x4_t tr0_6 = vreinterpretq_s32_s16(vzipq_s16(in[4], in[5]).val[1]);
-  const int32x4_t tr0_7 = vreinterpretq_s32_s16(vzipq_s16(in[6], in[7]).val[1]);
+  // Compute diagonal directions (0, 1, 3); direction 2 is handled below.
+  // Partial sums for lines 0 and 1.
+  int16x8_t partial0a = lines[0];
+  partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[1], 7));
+  int16x8_t partial0b = vextq_s16(lines[1], zero, 7);
+  int16x8_t partial1a = vaddq_s16(lines[0], vextq_s16(zero, lines[1], 6));
+  int16x8_t partial1b = vextq_s16(lines[1], zero, 6);
+  int16x8_t partial3a = vextq_s16(lines[0], zero, 2);
+  partial3a = vaddq_s16(partial3a, vextq_s16(lines[1], zero, 4));
+  int16x8_t partial3b = vextq_s16(zero, lines[0], 2);
+  partial3b = vaddq_s16(partial3b, vextq_s16(zero, lines[1], 4));
 
-  const int32x4_t tr1_0 = vzipq_s32(tr0_0, tr0_1).val[0];
-  const int32x4_t tr1_1 = vzipq_s32(tr0_4, tr0_5).val[0];
-  const int32x4_t tr1_2 = vzipq_s32(tr0_0, tr0_1).val[1];
-  const int32x4_t tr1_3 = vzipq_s32(tr0_4, tr0_5).val[1];
-  const int32x4_t tr1_4 = vzipq_s32(tr0_2, tr0_3).val[0];
-  const int32x4_t tr1_5 = vzipq_s32(tr0_6, tr0_7).val[0];
-  const int32x4_t tr1_6 = vzipq_s32(tr0_2, tr0_3).val[1];
-  const int32x4_t tr1_7 = vzipq_s32(tr0_6, tr0_7).val[1];
+  // Partial sums for lines 2 and 3.
+  partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[2], 6));
+  partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[3], 5));
+  partial0b = vaddq_s16(partial0b, vextq_s16(lines[2], zero, 6));
+  partial0b = vaddq_s16(partial0b, vextq_s16(lines[3], zero, 5));
+  partial1a = vaddq_s16(partial1a, vextq_s16(zero, lines[2], 4));
+  partial1a = vaddq_s16(partial1a, vextq_s16(zero, lines[3], 2));
+  partial1b = vaddq_s16(partial1b, vextq_s16(lines[2], zero, 4));
+  partial1b = vaddq_s16(partial1b, vextq_s16(lines[3], zero, 2));
+  partial3a = vaddq_s16(partial3a, vextq_s16(lines[2], zero, 6));
+  partial3b = vaddq_s16(partial3b, vextq_s16(zero, lines[2], 6));
+  partial3b = vaddq_s16(partial3b, lines[3]);
 
-  res[7] = vreinterpretq_s16_s64(ziplo_s64(tr1_0, tr1_1));
-  res[6] = vreinterpretq_s16_s64(ziphi_s64(tr1_0, tr1_1));
-  res[5] = vreinterpretq_s16_s64(ziplo_s64(tr1_2, tr1_3));
-  res[4] = vreinterpretq_s16_s64(ziphi_s64(tr1_2, tr1_3));
-  res[3] = vreinterpretq_s16_s64(ziplo_s64(tr1_4, tr1_5));
-  res[2] = vreinterpretq_s16_s64(ziphi_s64(tr1_4, tr1_5));
-  res[1] = vreinterpretq_s16_s64(ziplo_s64(tr1_6, tr1_7));
-  res[0] = vreinterpretq_s16_s64(ziphi_s64(tr1_6, tr1_7));
+  // Partial sums for lines 4 and 5.
+  partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[4], 4));
+  partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[5], 3));
+  partial0b = vaddq_s16(partial0b, vextq_s16(lines[4], zero, 4));
+  partial0b = vaddq_s16(partial0b, vextq_s16(lines[5], zero, 3));
+  partial1b = vaddq_s16(partial1b, lines[4]);
+  partial1b = vaddq_s16(partial1b, vextq_s16(zero, lines[5], 6));
+  int16x8_t partial1c = vextq_s16(lines[5], zero, 6);
+  partial3b = vaddq_s16(partial3b, vextq_s16(lines[4], zero, 2));
+  partial3b = vaddq_s16(partial3b, vextq_s16(lines[5], zero, 4));
+  int16x8_t partial3c = vextq_s16(zero, lines[4], 2);
+  partial3c = vaddq_s16(partial3c, vextq_s16(zero, lines[5], 4));
+
+  // Partial sums for lines 6 and 7.
+  partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[6], 2));
+  partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[7], 1));
+  partial0b = vaddq_s16(partial0b, vextq_s16(lines[6], zero, 2));
+  partial0b = vaddq_s16(partial0b, vextq_s16(lines[7], zero, 1));
+  partial1b = vaddq_s16(partial1b, vextq_s16(zero, lines[6], 4));
+  partial1b = vaddq_s16(partial1b, vextq_s16(zero, lines[7], 2));
+  partial1c = vaddq_s16(partial1c, vextq_s16(lines[6], zero, 4));
+  partial1c = vaddq_s16(partial1c, vextq_s16(lines[7], zero, 2));
+  partial3b = vaddq_s16(partial3b, vextq_s16(lines[6], zero, 6));
+  partial3c = vaddq_s16(partial3c, vextq_s16(zero, lines[6], 6));
+  partial3c = vaddq_s16(partial3c, lines[7]);
+
+  // Special case for direction 2 as it's just a sum along each line.
+  int16x8_t lines03[4] = { lines[0], lines[1], lines[2], lines[3] };
+  int16x8_t lines47[4] = { lines[4], lines[5], lines[6], lines[7] };
+  int32x4_t partial2a = horizontal_add_4d_s16x8(lines03);
+  int32x4_t partial2b = horizontal_add_4d_s16x8(lines47);
+
+  uint32x4_t partial2a_u32 =
+      vreinterpretq_u32_s32(vmulq_s32(partial2a, partial2a));
+  uint32x4_t partial2b_u32 =
+      vreinterpretq_u32_s32(vmulq_s32(partial2b, partial2b));
+
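+  // The weights are 840 / k, where k is the number of pixels summed along a
+  // line (840 is the LCM of 1..8): const0 and const1 give 840/{1..4} and
+  // 840/{5..8} for direction 0; const2 gives 840/{2,4,6,8} for directions 1
+  // and 3, whose partial lines come in even lengths.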
+  uint32x4_t const0 = vreinterpretq_u32_u64(
+      vcombine_u64(vcreate_u64((uint64_t)420 << 32 | 840),
+                   vcreate_u64((uint64_t)210 << 32 | 280)));
+  uint32x4_t const1 = vreinterpretq_u32_u64(
+      vcombine_u64(vcreate_u64((uint64_t)140 << 32 | 168),
+                   vcreate_u64((uint64_t)105 << 32 | 120)));
+  uint32x4_t const2 = vreinterpretq_u32_u64(
+      vcombine_u64(vcreate_u64((uint64_t)210 << 32 | 420),
+                   vcreate_u64((uint64_t)105 << 32 | 140)));
+
+  uint32x4_t costs[4];
+  costs[0] = fold_mul_and_sum_neon(partial0a, partial0b, const0, const1);
+  costs[1] =
+      fold_mul_and_sum_pairwise_neon(partial1a, partial1b, partial1c, const2);
+  costs[2] = vaddq_u32(partial2a_u32, partial2b_u32);
+  costs[2] = vmulq_n_u32(costs[2], 105);
+  costs[3] =
+      fold_mul_and_sum_pairwise_neon(partial3c, partial3b, partial3a, const2);
+
+  costs[0] = horizontal_add_4d_u32x4(costs);
+  vst1q_u32(cost, costs[0]);
+  return costs[0];
 }
 
 int cdef_find_dir_neon(const uint16_t *img, int stride, int32_t *var,
@@ -282,12 +392,10 @@
   }
 
   // Compute "mostly vertical" directions.
-  uint32x4_t cost47 = compute_directions_neon(lines, cost + 4);
-
-  array_reverse_transpose_8x8_neon(lines, lines);
+  uint32x4_t cost47 = compute_vert_directions_neon(lines, cost + 4);
 
   // Compute "mostly horizontal" directions.
-  uint32x4_t cost03 = compute_directions_neon(lines, cost);
+  uint32x4_t cost03 = compute_horiz_directions_neon(lines, cost);
 
   // Find max cost as well as its index to get best_dir.
   // The max cost needs to be propagated in the whole vector to find its