Optimize Neon implementation of cdef filter functions

Optimize the constraint function, make the computation of max and min
more parallel and simplify the computation of the final result.

Change-Id: I6e1a7993285e74165929dfc8185cb4f397f8478b
diff --git a/av1/common/arm/cdef_block_neon.c b/av1/common/arm/cdef_block_neon.c
index 68a292b..69ea49f 100644
--- a/av1/common/arm/cdef_block_neon.c
+++ b/av1/common/arm/cdef_block_neon.c
@@ -450,36 +450,37 @@
 // sign(a-b) * min(abs(a-b), max(0, threshold - (abs(a-b) >> adjdamp)))
 static INLINE int16x8_t constrain16(uint16x8_t a, uint16x8_t b,
                                     unsigned int threshold, int adjdamp) {
-  int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(a, b));
-  const int16x8_t sign = vshrq_n_s16(diff, 15);
-  diff = vabsq_s16(diff);
-  const uint16x8_t s =
-      vqsubq_u16(vdupq_n_u16(threshold),
-                 vreinterpretq_u16_s16(vshlq_s16(diff, vdupq_n_s16(-adjdamp))));
-  return veorq_s16(vaddq_s16(sign, vminq_s16(diff, vreinterpretq_s16_u16(s))),
-                   sign);
+  uint16x8_t diff = vabdq_u16(a, b);
+  const uint16x8_t a_gt_b = vcgtq_u16(a, b);
+  const uint16x8_t s = vqsubq_u16(vdupq_n_u16(threshold),
+                                  vshlq_u16(diff, vdupq_n_s16(-adjdamp)));
+  const int16x8_t clip = vreinterpretq_s16_u16(vminq_u16(diff, s));
+  return vbslq_s16(a_gt_b, clip, vnegq_s16(clip));
 }
 
 static INLINE uint16x8_t get_max_primary(const int is_lowbd, uint16x8_t *tap,
                                          uint16x8_t max,
                                          uint16x8_t cdef_large_value_mask) {
   if (is_lowbd) {
-    uint8x16_t max_u8 = vreinterpretq_u8_u16(tap[0]);
-    max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[1]));
-    max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[2]));
-    max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[3]));
-    /* The source is 16 bits, however, we only really care about the lower
-    8 bits.  The upper 8 bits contain the "large" flag.  After the final
-    primary max has been calculated, zero out the upper 8 bits.  Use this
-    to find the "16 bit" max. */
+    // The source is 16 bits, however, we only really care about the lower
+    // 8 bits.  The upper 8 bits contain the "large" flag.  After the final
+    // primary max has been calculated, zero out the upper 8 bits.  Use this
+    // to find the "16 bit" max.
+    uint8x16_t max0 =
+        vmaxq_u8(vreinterpretq_u8_u16(tap[0]), vreinterpretq_u8_u16(tap[1]));
+    uint8x16_t max1 =
+        vmaxq_u8(vreinterpretq_u8_u16(tap[2]), vreinterpretq_u8_u16(tap[3]));
+    max0 = vmaxq_u8(max0, max1);
     max = vmaxq_u16(
-        max, vandq_u16(vreinterpretq_u16_u8(max_u8), cdef_large_value_mask));
+        max, vandq_u16(vreinterpretq_u16_u8(max0), cdef_large_value_mask));
   } else {
-    /* Convert CDEF_VERY_LARGE to 0 before calculating max. */
-    max = vmaxq_u16(max, vandq_u16(tap[0], cdef_large_value_mask));
-    max = vmaxq_u16(max, vandq_u16(tap[1], cdef_large_value_mask));
-    max = vmaxq_u16(max, vandq_u16(tap[2], cdef_large_value_mask));
-    max = vmaxq_u16(max, vandq_u16(tap[3], cdef_large_value_mask));
+    // Convert CDEF_VERY_LARGE to 0 before calculating max.
+    uint16x8_t max0 = vmaxq_u16(vandq_u16(tap[0], cdef_large_value_mask),
+                                vandq_u16(tap[1], cdef_large_value_mask));
+    uint16x8_t max1 = vmaxq_u16(vandq_u16(tap[2], cdef_large_value_mask),
+                                vandq_u16(tap[3], cdef_large_value_mask));
+    max0 = vmaxq_u16(max0, max1);
+    max = vmaxq_u16(max0, max);
   }
   return max;
 }
@@ -488,30 +489,37 @@
                                            uint16x8_t max,
                                            uint16x8_t cdef_large_value_mask) {
   if (is_lowbd) {
-    uint8x16_t max_u8 = vreinterpretq_u8_u16(tap[0]);
-    max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[1]));
-    max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[2]));
-    max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[3]));
-    max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[4]));
-    max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[5]));
-    max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[6]));
-    max_u8 = vmaxq_u8(max_u8, vreinterpretq_u8_u16(tap[7]));
-    /* The source is 16 bits, however, we only really care about the lower
-    8 bits.  The upper 8 bits contain the "large" flag.  After the final
-    primary max has been calculated, zero out the upper 8 bits.  Use this
-    to find the "16 bit" max. */
+    // The source is 16 bits, however, we only really care about the lower
+    // 8 bits.  The upper 8 bits contain the "large" flag.  After the final
+    // primary max has been calculated, zero out the upper 8 bits.  Use this
+    // to find the "16 bit" max.
+    uint8x16_t max0 =
+        vmaxq_u8(vreinterpretq_u8_u16(tap[0]), vreinterpretq_u8_u16(tap[1]));
+    uint8x16_t max1 =
+        vmaxq_u8(vreinterpretq_u8_u16(tap[2]), vreinterpretq_u8_u16(tap[3]));
+    uint8x16_t max2 =
+        vmaxq_u8(vreinterpretq_u8_u16(tap[4]), vreinterpretq_u8_u16(tap[5]));
+    uint8x16_t max3 =
+        vmaxq_u8(vreinterpretq_u8_u16(tap[6]), vreinterpretq_u8_u16(tap[7]));
+    max0 = vmaxq_u8(max0, max1);
+    max2 = vmaxq_u8(max2, max3);
+    max0 = vmaxq_u8(max0, max2);
     max = vmaxq_u16(
-        max, vandq_u16(vreinterpretq_u16_u8(max_u8), cdef_large_value_mask));
+        max, vandq_u16(vreinterpretq_u16_u8(max0), cdef_large_value_mask));
   } else {
-    /* Convert CDEF_VERY_LARGE to 0 before calculating max. */
-    max = vmaxq_u16(max, vandq_u16(tap[0], cdef_large_value_mask));
-    max = vmaxq_u16(max, vandq_u16(tap[1], cdef_large_value_mask));
-    max = vmaxq_u16(max, vandq_u16(tap[2], cdef_large_value_mask));
-    max = vmaxq_u16(max, vandq_u16(tap[3], cdef_large_value_mask));
-    max = vmaxq_u16(max, vandq_u16(tap[4], cdef_large_value_mask));
-    max = vmaxq_u16(max, vandq_u16(tap[5], cdef_large_value_mask));
-    max = vmaxq_u16(max, vandq_u16(tap[6], cdef_large_value_mask));
-    max = vmaxq_u16(max, vandq_u16(tap[7], cdef_large_value_mask));
+    // Convert CDEF_VERY_LARGE to 0 before calculating max.
+    uint16x8_t max0 = vmaxq_u16(vandq_u16(tap[0], cdef_large_value_mask),
+                                vandq_u16(tap[1], cdef_large_value_mask));
+    uint16x8_t max1 = vmaxq_u16(vandq_u16(tap[2], cdef_large_value_mask),
+                                vandq_u16(tap[3], cdef_large_value_mask));
+    uint16x8_t max2 = vmaxq_u16(vandq_u16(tap[4], cdef_large_value_mask),
+                                vandq_u16(tap[5], cdef_large_value_mask));
+    uint16x8_t max3 = vmaxq_u16(vandq_u16(tap[6], cdef_large_value_mask),
+                                vandq_u16(tap[7], cdef_large_value_mask));
+    max0 = vmaxq_u16(max0, max1);
+    max2 = vmaxq_u16(max2, max3);
+    max0 = vmaxq_u16(max0, max2);
+    max = vmaxq_u16(max, max0);
   }
   return max;
 }
@@ -576,10 +584,10 @@
       if (clipping_required) {
         max = get_max_primary(is_lowbd, tap, max, cdef_large_value_mask);
 
-        min = vminq_u16(min, tap[0]);
-        min = vminq_u16(min, tap[1]);
-        min = vminq_u16(min, tap[2]);
-        min = vminq_u16(min, tap[3]);
+        uint16x8_t min1 = vminq_u16(tap[0], tap[1]);
+        uint16x8_t min2 = vminq_u16(tap[2], tap[3]);
+        min1 = vminq_u16(min1, min2);
+        min = vminq_u16(min, min1);
       }
     }
 
@@ -621,22 +629,20 @@
       if (clipping_required) {
         max = get_max_secondary(is_lowbd, tap, max, cdef_large_value_mask);
 
-        min = vminq_u16(min, tap[0]);
-        min = vminq_u16(min, tap[1]);
-        min = vminq_u16(min, tap[2]);
-        min = vminq_u16(min, tap[3]);
-        min = vminq_u16(min, tap[4]);
-        min = vminq_u16(min, tap[5]);
-        min = vminq_u16(min, tap[6]);
-        min = vminq_u16(min, tap[7]);
+        uint16x8_t min0 = vminq_u16(tap[0], tap[1]);
+        uint16x8_t min1 = vminq_u16(tap[2], tap[3]);
+        uint16x8_t min2 = vminq_u16(tap[4], tap[5]);
+        uint16x8_t min3 = vminq_u16(tap[6], tap[7]);
+        min0 = vminq_u16(min0, min1);
+        min2 = vminq_u16(min2, min3);
+        min0 = vminq_u16(min0, min2);
+        min = vminq_u16(min, min0);
       }
     }
 
     // res = row + ((sum - (sum < 0) + 8) >> 4)
     sum = vaddq_s16(sum, vreinterpretq_s16_u16(vcltq_s16(sum, vdupq_n_s16(0))));
-    int16x8_t res = vaddq_s16(sum, vdupq_n_s16(8));
-    res = vshrq_n_s16(res, 4);
-    res = vaddq_s16(vreinterpretq_s16_u16(s), res);
+    int16x8_t res = vrsraq_n_s16(vreinterpretq_s16_u16(s), sum, 4);
 
     if (clipping_required) {
       res = vminq_s16(vmaxq_s16(res, vreinterpretq_s16_u16(min)),
@@ -716,10 +722,10 @@
       if (clipping_required) {
         max = get_max_primary(is_lowbd, tap, max, cdef_large_value_mask);
 
-        min = vminq_u16(min, tap[0]);
-        min = vminq_u16(min, tap[1]);
-        min = vminq_u16(min, tap[2]);
-        min = vminq_u16(min, tap[3]);
+        uint16x8_t min0 = vminq_u16(tap[0], tap[1]);
+        uint16x8_t min1 = vminq_u16(tap[2], tap[3]);
+        min0 = vminq_u16(min0, min1);
+        min = vminq_u16(min, min0);
       }
     }
 
@@ -761,22 +767,21 @@
       if (clipping_required) {
         max = get_max_secondary(is_lowbd, tap, max, cdef_large_value_mask);
 
-        min = vminq_u16(min, tap[0]);
-        min = vminq_u16(min, tap[1]);
-        min = vminq_u16(min, tap[2]);
-        min = vminq_u16(min, tap[3]);
-        min = vminq_u16(min, tap[4]);
-        min = vminq_u16(min, tap[5]);
-        min = vminq_u16(min, tap[6]);
-        min = vminq_u16(min, tap[7]);
+        uint16x8_t min0 = vminq_u16(tap[0], tap[1]);
+        uint16x8_t min1 = vminq_u16(tap[2], tap[3]);
+        uint16x8_t min2 = vminq_u16(tap[4], tap[5]);
+        uint16x8_t min3 = vminq_u16(tap[6], tap[7]);
+        min0 = vminq_u16(min0, min1);
+        min2 = vminq_u16(min2, min3);
+        min0 = vminq_u16(min0, min2);
+        min = vminq_u16(min, min0);
       }
     }
 
     // res = row + ((sum - (sum < 0) + 8) >> 4)
     sum = vaddq_s16(sum, vreinterpretq_s16_u16(vcltq_s16(sum, vdupq_n_s16(0))));
-    int16x8_t res = vaddq_s16(sum, vdupq_n_s16(8));
-    res = vshrq_n_s16(res, 4);
-    res = vaddq_s16(vreinterpretq_s16_u16(s), res);
+    int16x8_t res = vrsraq_n_s16(vreinterpretq_s16_u16(s), sum, 4);
+
     if (clipping_required) {
       res = vminq_s16(vmaxq_s16(res, vreinterpretq_s16_u16(min)),
                       vreinterpretq_s16_u16(max));