Simplify finding final direction in cdef_find_dir_neon

Simplify and optimize the final computation of cdef_find_dir_neon to
find the maximum cost and the associated index in the vector.

Change-Id: Id9821597b49c30689d63fff00ee1c748622faf69
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index b5a8b97..4261592 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -17,6 +17,16 @@
 #include "aom/aom_integer.h"
 #include "aom_ports/mem.h"
 
+static INLINE int horizontal_add_u8x8(const uint8x8_t a) {
+#if AOM_ARCH_AARCH64
+  return vaddlv_u8(a);
+#else
+  uint16x4_t b = vpaddl_u8(a);
+  uint32x2_t c = vpaddl_u16(b);
+  return vget_lane_u32(c, 0) + vget_lane_u32(c, 1);
+#endif
+}
+
 static INLINE int horizontal_add_s16x8(const int16x8_t a) {
 #if AOM_ARCH_AARCH64
   return vaddlvq_s16(a);
diff --git a/aom_ports/bitops.h b/aom_ports/bitops.h
index 7f4c165..a509628 100644
--- a/aom_ports/bitops.h
+++ b/aom_ports/bitops.h
@@ -13,6 +13,7 @@
 #define AOM_AOM_PORTS_BITOPS_H_
 
 #include <assert.h>
+#include <stdint.h>
 
 #include "aom_ports/msvc.h"
 #include "config/aom_config.h"
@@ -52,7 +53,6 @@
   _BitScanReverse(&first_set_bit, n);
   return first_set_bit;
 }
-#undef USE_MSC_INTRINSICS
 #else
 static INLINE int get_msb(unsigned int n) {
   int log = 0;
@@ -71,6 +71,32 @@
 }
 #endif
 
+#if defined(__GNUC__) && \
+    ((__GNUC__ == 3 && __GNUC_MINOR__ >= 4) || __GNUC__ >= 4)
+static INLINE int aom_clzll(uint64_t n) { return __builtin_clzll(n); }
+#elif defined(USE_MSC_INTRINSICS)
+#pragma intrinsic(_BitScanReverse64)
+
+static INLINE int aom_clzll(uint64_t n) {
+  int res;
+  _BitScanReverse64(&res, n);
+  return res;
+}
+#undef USE_MSC_INTRINSICS
+#else
+static INLINE int aom_clzll(uint64_t n) {
+  assert(n != 0);
+
+  int res = 0;
+  uint64_t high_bit = 1ULL << 63;
+  while (!(n & high_bit)) {
+    res++;
+    n <<= 1;
+  }
+  return res;
+}
+#endif
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/common/arm/cdef_block_neon.c b/av1/common/arm/cdef_block_neon.c
index a6567fe..f69e9c4 100644
--- a/av1/common/arm/cdef_block_neon.c
+++ b/av1/common/arm/cdef_block_neon.c
@@ -16,6 +16,7 @@
 #include "config/av1_rtcd.h"
 
 #include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
 #include "av1/common/cdef_block.h"
 
 void cdef_copy_rect8_8bit_to_16bit_neon(uint16_t *dst, int dstride,
@@ -363,19 +364,6 @@
   res[0] = vreinterpretq_s16_s64(ziphi_s64(tr1_6, tr1_7));
 }
 
-static INLINE uint32_t compute_best_dir(uint8x16_t a) {
-  uint8x16_t idx =
-      vandq_u8(a, vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201ULL)));
-#if AOM_ARCH_AARCH64
-  return vaddv_u8(vget_low_u8(idx)) + (vaddv_u8(vget_high_u8(idx)) << 8);
-#else
-  uint64x2_t m = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(idx)));
-  uint8x16_t s = vreinterpretq_u8_u64(m);
-  return vget_lane_u32(
-      vreinterpret_u32_u8(vzip_u8(vget_low_u8(s), vget_high_u8(s)).val[0]), 0);
-#endif
-}
-
 int cdef_find_dir_neon(const uint16_t *img, int stride, int32_t *var,
                        int coeff_shift) {
   uint32_t cost[8];
@@ -396,15 +384,35 @@
   // Compute "mostly horizontal" directions.
   uint32x4_t cost03 = compute_directions_neon(lines, cost);
 
-  uint32x4_t max_cost = vmaxq_u32(cost03, cost47);
-  max_cost = vmaxq_u32(max_cost, vextq_u32(max_cost, max_cost, 2));
-  max_cost = vmaxq_u32(max_cost, vextq_u32(max_cost, max_cost, 1));
-  best_cost = vgetq_lane_u32(max_cost, 0);
-  uint16x8_t idx = vcombine_u16(vqmovn_u32(vceqq_u32(max_cost, cost03)),
-                                vqmovn_u32(vceqq_u32(max_cost, cost47)));
-  uint8x16_t idx_u8 = vcombine_u8(vqmovn_u16(idx), vqmovn_u16(idx));
-  best_dir = compute_best_dir(idx_u8);
-  best_dir = get_msb(best_dir ^ (best_dir - 1));  // Count trailing zeros
+  // Find max cost as well as its index to get best_dir.
+  // The max cost needs to be propagated in the whole vector to find its
+  // position in the original cost vectors cost03 and cost47.
+  uint32x4_t cost07 = vmaxq_u32(cost03, cost47);
+#if AOM_ARCH_AARCH64
+  best_cost = vmaxvq_u32(cost07);
+  uint32x4_t max_cost = vdupq_n_u32(best_cost);
+  uint8x16x2_t costs = { { vreinterpretq_u8_u32(vceqq_u32(max_cost, cost03)),
+                           vreinterpretq_u8_u32(
+                               vceqq_u32(max_cost, cost47)) } };
+  // idx = { 28, 24, 20, 16, 12, 8, 4, 0 };
+  uint8x8_t idx = vreinterpret_u8_u64(vcreate_u64(0x0004080c1014181cULL));
+  // Get the lowest 8 bit of each 32-bit elements and reverse them.
+  uint8x8_t tbl = vqtbl2_u8(costs, idx);
+  uint64_t a = vget_lane_u64(vreinterpret_u64_u8(tbl), 0);
+  best_dir = aom_clzll(a) >> 3;
+#else
+  uint32x2_t cost64 = vpmax_u32(vget_low_u32(cost07), vget_high_u32(cost07));
+  cost64 = vpmax_u32(cost64, cost64);
+  uint32x4_t max_cost = vcombine_u32(cost64, cost64);
+  best_cost = vget_lane_u32(cost64, 0);
+  uint16x8_t costs = vcombine_u16(vmovn_u32(vceqq_u32(max_cost, cost03)),
+                                  vmovn_u32(vceqq_u32(max_cost, cost47)));
+  uint8x8_t idx =
+      vand_u8(vmovn_u16(costs),
+              vreinterpret_u8_u64(vcreate_u64(0x8040201008040201ULL)));
+  int sum = horizontal_add_u8x8(idx);
+  best_dir = get_msb(sum ^ (sum - 1));
+#endif
 
   // Difference between the optimal variance and the variance along the
   // orthogonal direction. Again, the sum(x^2) terms cancel out.