Optimize Neon implementation of aom_lpf_14

The Neon implementations of aom_lpf_vertical_14 and
aom_lpf_horizontal_14 compute every filter (filter4(), filter8() and
filter14()) before selecting for each element which filter is actually
needed. In practice, however, a lot of cases only need one of the
filters, so specialize for these scenarios, computing only the filters
that are needed and eliminate bitwise select.  This makes the case where
all filters are needed slightly slower, but as it is far from the most
common case this is ok.

Also move the actual filter computation to separate functions to avoid
code duplication.

This is a port from 4d6f560d1f4a35a47e5dfc9b7c2dc4f5483d49dc in SVT-AV1.

Change-Id: Id608edad6fee69fc263eb6ee9e0956ca6311d77f
diff --git a/aom_dsp/arm/loopfilter_neon.c b/aom_dsp/arm/loopfilter_neon.c
index 7822a14..fe201bc 100644
--- a/aom_dsp/arm/loopfilter_neon.c
+++ b/aom_dsp/arm/loopfilter_neon.c
@@ -146,192 +146,6 @@
   return mask_8x8;
 }
 
-static void lpf_14_neon(uint8x8_t *p6q6, uint8x8_t *p5q5, uint8x8_t *p4q4,
-                        uint8x8_t *p3q3, uint8x8_t *p2q2, uint8x8_t *p1q1,
-                        uint8x8_t *p0q0, const uint8_t blimit,
-                        const uint8_t limit, const uint8_t thresh) {
-  uint16x8_t out;
-  uint8x8_t out_f14_pq0, out_f14_pq1, out_f14_pq2, out_f14_pq3, out_f14_pq4,
-      out_f14_pq5;
-  uint8x8_t out_f7_pq0, out_f7_pq1, out_f7_pq2;
-  uint8x8_t out_f4_pq0, out_f4_pq1;
-  uint8x8_t mask_8x8, flat_8x8, flat2_8x8;
-  uint8x8_t q0p0, q1p1, q2p2;
-
-  // Calculate filter masks
-  mask_8x8 = lpf_mask(*p3q3, *p2q2, *p1q1, *p0q0, blimit, limit);
-  flat_8x8 = lpf_flat_mask4(*p3q3, *p2q2, *p1q1, *p0q0);
-  flat2_8x8 = lpf_flat_mask4(*p6q6, *p5q5, *p4q4, *p0q0);
-  {
-    // filter 4
-    int32x2x2_t ps0_qs0, ps1_qs1;
-    int16x8_t filter_s16;
-    const uint8x8_t thresh_f4 = vdup_n_u8(thresh);
-    uint8x8_t temp0_8x8, temp1_8x8;
-    int8x8_t ps0_s8, ps1_s8, qs0_s8, qs1_s8, temp_s8;
-    int8x8_t op0, oq0, op1, oq1;
-    int8x8_t pq_s0, pq_s1;
-    int8x8_t filter_s8, filter1_s8, filter2_s8;
-    int8x8_t hev_8x8;
-    const int8x8_t sign_mask = vdup_n_s8(0x80);
-    const int8x8_t val_4 = vdup_n_s8(4);
-    const int8x8_t val_3 = vdup_n_s8(3);
-
-    pq_s0 = veor_s8(vreinterpret_s8_u8(*p0q0), sign_mask);
-    pq_s1 = veor_s8(vreinterpret_s8_u8(*p1q1), sign_mask);
-
-    ps0_qs0 = vtrn_s32(vreinterpret_s32_s8(pq_s0), vreinterpret_s32_s8(pq_s0));
-    ps1_qs1 = vtrn_s32(vreinterpret_s32_s8(pq_s1), vreinterpret_s32_s8(pq_s1));
-    ps0_s8 = vreinterpret_s8_s32(ps0_qs0.val[0]);
-    qs0_s8 = vreinterpret_s8_s32(ps0_qs0.val[1]);
-    ps1_s8 = vreinterpret_s8_s32(ps1_qs1.val[0]);
-    qs1_s8 = vreinterpret_s8_s32(ps1_qs1.val[1]);
-
-    // hev_mask
-    temp0_8x8 = vcgt_u8(vabd_u8(*p0q0, *p1q1), thresh_f4);
-    temp1_8x8 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(temp0_8x8)));
-    hev_8x8 = vreinterpret_s8_u8(vorr_u8(temp0_8x8, temp1_8x8));
-
-    // add outer taps if we have high edge variance
-    filter_s8 = vqsub_s8(ps1_s8, qs1_s8);
-    filter_s8 = vand_s8(filter_s8, hev_8x8);
-
-    // inner taps
-    temp_s8 = vqsub_s8(qs0_s8, ps0_s8);
-    filter_s16 = vmovl_s8(filter_s8);
-    filter_s16 = vmlal_s8(filter_s16, temp_s8, val_3);
-    filter_s8 = vqmovn_s16(filter_s16);
-    filter_s8 = vand_s8(filter_s8, vreinterpret_s8_u8(mask_8x8));
-
-    filter1_s8 = vqadd_s8(filter_s8, val_4);
-    filter2_s8 = vqadd_s8(filter_s8, val_3);
-    filter1_s8 = vshr_n_s8(filter1_s8, 3);
-    filter2_s8 = vshr_n_s8(filter2_s8, 3);
-
-    oq0 = veor_s8(vqsub_s8(qs0_s8, filter1_s8), sign_mask);
-    op0 = veor_s8(vqadd_s8(ps0_s8, filter2_s8), sign_mask);
-
-    hev_8x8 = vmvn_s8(hev_8x8);
-    filter_s8 = vrshr_n_s8(filter1_s8, 1);
-    filter_s8 = vand_s8(filter_s8, hev_8x8);
-
-    oq1 = veor_s8(vqsub_s8(qs1_s8, filter_s8), sign_mask);
-    op1 = veor_s8(vqadd_s8(ps1_s8, filter_s8), sign_mask);
-
-    out_f4_pq0 = vreinterpret_u8_s8(vext_s8(op0, oq0, 4));
-    out_f4_pq1 = vreinterpret_u8_s8(vext_s8(op1, oq1, 4));
-  }
-  // reverse p and q
-  q0p0 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p0q0)));
-  q1p1 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p1q1)));
-  q2p2 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p2q2)));
-  {
-    // filter 8
-    uint16x8_t out_pq0, out_pq1, out_pq2;
-    out = vaddl_u8(*p3q3, *p2q2);
-    out = vaddw_u8(out, *p1q1);
-    out = vaddw_u8(out, *p0q0);
-
-    out = vaddw_u8(out, q0p0);
-    out_pq1 = vaddw_u8(out, *p3q3);
-    out_pq2 = vaddw_u8(out_pq1, *p3q3);
-    out_pq2 = vaddw_u8(out_pq2, *p2q2);
-    out_pq1 = vaddw_u8(out_pq1, *p1q1);
-    out_pq1 = vaddw_u8(out_pq1, q1p1);
-
-    out_pq0 = vaddw_u8(out, *p0q0);
-    out_pq0 = vaddw_u8(out_pq0, q1p1);
-    out_pq0 = vaddw_u8(out_pq0, q2p2);
-
-    out_f7_pq0 = vrshrn_n_u16(out_pq0, 3);
-    out_f7_pq1 = vrshrn_n_u16(out_pq1, 3);
-    out_f7_pq2 = vrshrn_n_u16(out_pq2, 3);
-  }
-  {
-    // filter 14
-    uint16x8_t out_pq0, out_pq1, out_pq2, out_pq3, out_pq4, out_pq5;
-    uint16x8_t p6q6_2, p6q6_temp, qp_sum;
-    uint8x8_t qp_rev;
-
-    out = vaddw_u8(out, *p4q4);
-    out = vaddw_u8(out, *p5q5);
-    out = vaddw_u8(out, *p6q6);
-
-    out_pq5 = vaddw_u8(out, *p4q4);
-    out_pq4 = vaddw_u8(out_pq5, *p3q3);
-    out_pq3 = vaddw_u8(out_pq4, *p2q2);
-
-    out_pq5 = vaddw_u8(out_pq5, *p5q5);
-    out_pq4 = vaddw_u8(out_pq4, *p5q5);
-
-    out_pq0 = vaddw_u8(out, *p1q1);
-    out_pq1 = vaddw_u8(out_pq0, *p2q2);
-    out_pq2 = vaddw_u8(out_pq1, *p3q3);
-
-    out_pq0 = vaddw_u8(out_pq0, *p0q0);
-    out_pq1 = vaddw_u8(out_pq1, *p0q0);
-
-    out_pq1 = vaddw_u8(out_pq1, *p6q6);
-    p6q6_2 = vaddl_u8(*p6q6, *p6q6);
-    out_pq2 = vaddq_u16(out_pq2, p6q6_2);
-    p6q6_temp = vaddw_u8(p6q6_2, *p6q6);
-    out_pq3 = vaddq_u16(out_pq3, p6q6_temp);
-    p6q6_temp = vaddw_u8(p6q6_temp, *p6q6);
-    out_pq4 = vaddq_u16(out_pq4, p6q6_temp);
-    p6q6_temp = vaddq_u16(p6q6_temp, p6q6_2);
-    out_pq5 = vaddq_u16(out_pq5, p6q6_temp);
-
-    out_pq4 = vaddw_u8(out_pq4, q1p1);
-
-    qp_sum = vaddl_u8(q2p2, q1p1);
-    out_pq3 = vaddq_u16(out_pq3, qp_sum);
-
-    qp_rev = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p3q3)));
-    qp_sum = vaddw_u8(qp_sum, qp_rev);
-    out_pq2 = vaddq_u16(out_pq2, qp_sum);
-
-    qp_rev = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p4q4)));
-    qp_sum = vaddw_u8(qp_sum, qp_rev);
-    out_pq1 = vaddq_u16(out_pq1, qp_sum);
-
-    qp_rev = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p5q5)));
-    qp_sum = vaddw_u8(qp_sum, qp_rev);
-    out_pq0 = vaddq_u16(out_pq0, qp_sum);
-
-    out_pq0 = vaddw_u8(out_pq0, q0p0);
-
-    out_f14_pq0 = vrshrn_n_u16(out_pq0, 4);
-    out_f14_pq1 = vrshrn_n_u16(out_pq1, 4);
-    out_f14_pq2 = vrshrn_n_u16(out_pq2, 4);
-    out_f14_pq3 = vrshrn_n_u16(out_pq3, 4);
-    out_f14_pq4 = vrshrn_n_u16(out_pq4, 4);
-    out_f14_pq5 = vrshrn_n_u16(out_pq5, 4);
-  }
-  {
-    uint8x8_t filter4_cond, filter8_cond, filter14_cond;
-    filter8_cond = vand_u8(flat_8x8, mask_8x8);
-    filter4_cond = vmvn_u8(filter8_cond);
-    filter14_cond = vand_u8(filter8_cond, flat2_8x8);
-
-    // filter4 outputs
-    *p0q0 = vbsl_u8(filter4_cond, out_f4_pq0, *p0q0);
-    *p1q1 = vbsl_u8(filter4_cond, out_f4_pq1, *p1q1);
-
-    // filter8 outputs
-    *p0q0 = vbsl_u8(filter8_cond, out_f7_pq0, *p0q0);
-    *p1q1 = vbsl_u8(filter8_cond, out_f7_pq1, *p1q1);
-    *p2q2 = vbsl_u8(filter8_cond, out_f7_pq2, *p2q2);
-
-    // filter14 outputs
-    *p0q0 = vbsl_u8(filter14_cond, out_f14_pq0, *p0q0);
-    *p1q1 = vbsl_u8(filter14_cond, out_f14_pq1, *p1q1);
-    *p2q2 = vbsl_u8(filter14_cond, out_f14_pq2, *p2q2);
-    *p3q3 = vbsl_u8(filter14_cond, out_f14_pq3, *p3q3);
-    *p4q4 = vbsl_u8(filter14_cond, out_f14_pq4, *p4q4);
-    *p5q5 = vbsl_u8(filter14_cond, out_f14_pq5, *p5q5);
-  }
-}
-
 static inline void filter4(const uint8x8_t p0q0, const uint8x8_t p1q1,
                            uint8x8_t *p0q0_output, uint8x8_t *p1q1_output,
                            uint8x8_t mask_8x8, const uint8_t thresh) {
@@ -418,6 +232,173 @@
   *p2q2_output = vrshrn_n_u16(out_pq2, 3);
 }
 
+static inline void filter14(const uint8x8_t p0q0, const uint8x8_t p1q1,
+                            const uint8x8_t p2q2, const uint8x8_t p3q3,
+                            const uint8x8_t p4q4, const uint8x8_t p5q5,
+                            const uint8x8_t p6q6, uint8x8_t *p0q0_output,
+                            uint8x8_t *p1q1_output, uint8x8_t *p2q2_output,
+                            uint8x8_t *p3q3_output, uint8x8_t *p4q4_output,
+                            uint8x8_t *p5q5_output) {
+  // Reverse p and q.
+  uint8x8_t q0p0 = vext_u8(p0q0, p0q0, 4);
+  uint8x8_t q1p1 = vext_u8(p1q1, p1q1, 4);
+  uint8x8_t q2p2 = vext_u8(p2q2, p2q2, 4);
+  uint8x8_t q3p3 = vext_u8(p3q3, p3q3, 4);
+  uint8x8_t q4p4 = vext_u8(p4q4, p4q4, 4);
+  uint8x8_t q5p5 = vext_u8(p5q5, p5q5, 4);
+
+  uint16x8_t p0q0_p1q1 = vaddl_u8(p0q0, p1q1);
+  uint16x8_t p2q2_p3q3 = vaddl_u8(p2q2, p3q3);
+  uint16x8_t out = vaddq_u16(p0q0_p1q1, p2q2_p3q3);
+
+  uint16x8_t q0p0_p4q4 = vaddl_u8(q0p0, p4q4);
+  uint16x8_t p5q5_p6q6 = vaddl_u8(p5q5, p6q6);
+  uint16x8_t tmp = vaddq_u16(q0p0_p4q4, p5q5_p6q6);
+  // This offset removes the need for a rounding shift at the end.
+  uint16x8_t tmp_offset = vaddq_u16(tmp, vdupq_n_u16(1 << 3));
+  out = vaddq_u16(out, tmp_offset);
+
+  uint16x8_t out_pq5 = vaddw_u8(out, p4q4);
+  uint16x8_t out_pq4 = vaddw_u8(out_pq5, p3q3);
+  uint16x8_t out_pq3 = vaddw_u8(out_pq4, p2q2);
+
+  out_pq5 = vaddw_u8(out_pq5, p5q5);
+
+  uint16x8_t out_pq0 = vaddw_u8(out, p1q1);
+  uint16x8_t out_pq1 = vaddw_u8(out_pq0, p2q2);
+  uint16x8_t out_pq2 = vaddw_u8(out_pq1, p3q3);
+
+  uint16x8_t p0q0_q0p0 = vaddl_u8(p0q0, q0p0);
+  out_pq0 = vaddq_u16(out_pq0, p0q0_q0p0);
+
+  uint16x8_t p0q0_p6q6 = vaddl_u8(p0q0, p6q6);
+  out_pq1 = vaddq_u16(out_pq1, p0q0_p6q6);
+  uint16x8_t p5q5_q1p1 = vaddl_u8(p5q5, q1p1);
+  out_pq4 = vaddq_u16(out_pq4, p5q5_q1p1);
+
+  uint16x8_t p6q6_p6q6 = vaddl_u8(p6q6, p6q6);
+  out_pq2 = vaddq_u16(out_pq2, p6q6_p6q6);
+  uint16x8_t p6q6_temp = vaddw_u8(p6q6_p6q6, p6q6);
+  out_pq3 = vaddq_u16(out_pq3, p6q6_temp);
+  p6q6_temp = vaddw_u8(p6q6_temp, p6q6);
+  out_pq4 = vaddq_u16(out_pq4, p6q6_temp);
+  p6q6_temp = vaddq_u16(p6q6_temp, p6q6_p6q6);
+  out_pq5 = vaddq_u16(out_pq5, p6q6_temp);
+
+  uint16x8_t qp_sum = vaddl_u8(q2p2, q1p1);
+  out_pq3 = vaddq_u16(out_pq3, qp_sum);
+
+  qp_sum = vaddw_u8(qp_sum, q3p3);
+  out_pq2 = vaddq_u16(out_pq2, qp_sum);
+
+  qp_sum = vaddw_u8(qp_sum, q4p4);
+  out_pq1 = vaddq_u16(out_pq1, qp_sum);
+
+  qp_sum = vaddw_u8(qp_sum, q5p5);
+  out_pq0 = vaddq_u16(out_pq0, qp_sum);
+
+  *p0q0_output = vshrn_n_u16(out_pq0, 4);
+  *p1q1_output = vshrn_n_u16(out_pq1, 4);
+  *p2q2_output = vshrn_n_u16(out_pq2, 4);
+  *p3q3_output = vshrn_n_u16(out_pq3, 4);
+  *p4q4_output = vshrn_n_u16(out_pq4, 4);
+  *p5q5_output = vshrn_n_u16(out_pq5, 4);
+}
+
+static inline void lpf_14_neon(uint8x8_t *p6q6, uint8x8_t *p5q5,
+                               uint8x8_t *p4q4, uint8x8_t *p3q3,
+                               uint8x8_t *p2q2, uint8x8_t *p1q1,
+                               uint8x8_t *p0q0, const uint8_t blimit,
+                               const uint8_t limit, const uint8_t thresh) {
+  uint8x8_t out_f14_pq0, out_f14_pq1, out_f14_pq2, out_f14_pq3, out_f14_pq4,
+      out_f14_pq5;
+  uint8x8_t out_f7_pq0, out_f7_pq1, out_f7_pq2;
+  uint8x8_t out_f4_pq0, out_f4_pq1;
+
+  // Calculate filter masks.
+  uint8x8_t mask_8x8 = lpf_mask(*p3q3, *p2q2, *p1q1, *p0q0, blimit, limit);
+  uint8x8_t flat_8x8 = lpf_flat_mask4(*p3q3, *p2q2, *p1q1, *p0q0);
+  uint8x8_t flat2_8x8 = lpf_flat_mask4(*p6q6, *p5q5, *p4q4, *p0q0);
+
+  // No filtering.
+  if (vget_lane_u64(vreinterpret_u64_u8(mask_8x8), 0) == 0) {
+    return;
+  }
+
+  uint8x8_t filter8_cond = vand_u8(flat_8x8, mask_8x8);
+  uint8x8_t filter4_cond = vmvn_u8(filter8_cond);
+  uint8x8_t filter14_cond = vand_u8(filter8_cond, flat2_8x8);
+
+  if (vget_lane_s64(vreinterpret_s64_u8(filter14_cond), 0) == -1) {
+    // Only filter14() applies.
+    filter14(*p0q0, *p1q1, *p2q2, *p3q3, *p4q4, *p5q5, *p6q6, &out_f14_pq0,
+             &out_f14_pq1, &out_f14_pq2, &out_f14_pq3, &out_f14_pq4,
+             &out_f14_pq5);
+
+    *p0q0 = out_f14_pq0;
+    *p1q1 = out_f14_pq1;
+    *p2q2 = out_f14_pq2;
+    *p3q3 = out_f14_pq3;
+    *p4q4 = out_f14_pq4;
+    *p5q5 = out_f14_pq5;
+  } else if (vget_lane_u64(vreinterpret_u64_u8(filter14_cond), 0) == 0 &&
+             vget_lane_s64(vreinterpret_s64_u8(filter8_cond), 0) == -1) {
+    // Only filter8() applies.
+    filter8(*p0q0, *p1q1, *p2q2, *p3q3, &out_f7_pq0, &out_f7_pq1, &out_f7_pq2);
+
+    *p0q0 = out_f7_pq0;
+    *p1q1 = out_f7_pq1;
+    *p2q2 = out_f7_pq2;
+  } else {
+    filter4(*p0q0, *p1q1, &out_f4_pq0, &out_f4_pq1, mask_8x8, thresh);
+
+    if (vget_lane_u64(vreinterpret_u64_u8(filter14_cond), 0) == 0 &&
+        vget_lane_u64(vreinterpret_u64_u8(filter8_cond), 0) == 0) {
+      // filter8() and filter14() do not apply, but filter4() applies to one or
+      // more values.
+      *p0q0 = vbsl_u8(filter4_cond, out_f4_pq0, *p0q0);
+      *p1q1 = vbsl_u8(filter4_cond, out_f4_pq1, *p1q1);
+    } else {
+      filter8(*p0q0, *p1q1, *p2q2, *p3q3, &out_f7_pq0, &out_f7_pq1,
+              &out_f7_pq2);
+
+      if (vget_lane_u64(vreinterpret_u64_u8(filter14_cond), 0) == 0) {
+        // filter14() does not apply, but filter8() and filter4() apply to one
+        // or more values. filter4 outputs
+        *p0q0 = vbsl_u8(filter4_cond, out_f4_pq0, *p0q0);
+        *p1q1 = vbsl_u8(filter4_cond, out_f4_pq1, *p1q1);
+
+        // filter8 outputs
+        *p0q0 = vbsl_u8(filter8_cond, out_f7_pq0, *p0q0);
+        *p1q1 = vbsl_u8(filter8_cond, out_f7_pq1, *p1q1);
+        *p2q2 = vbsl_u8(filter8_cond, out_f7_pq2, *p2q2);
+      } else {
+        // All filters may contribute values to final outputs.
+        filter14(*p0q0, *p1q1, *p2q2, *p3q3, *p4q4, *p5q5, *p6q6, &out_f14_pq0,
+                 &out_f14_pq1, &out_f14_pq2, &out_f14_pq3, &out_f14_pq4,
+                 &out_f14_pq5);
+
+        // filter4 outputs
+        *p0q0 = vbsl_u8(filter4_cond, out_f4_pq0, *p0q0);
+        *p1q1 = vbsl_u8(filter4_cond, out_f4_pq1, *p1q1);
+
+        // filter8 outputs
+        *p0q0 = vbsl_u8(filter8_cond, out_f7_pq0, *p0q0);
+        *p1q1 = vbsl_u8(filter8_cond, out_f7_pq1, *p1q1);
+        *p2q2 = vbsl_u8(filter8_cond, out_f7_pq2, *p2q2);
+
+        // filter14 outputs
+        *p0q0 = vbsl_u8(filter14_cond, out_f14_pq0, *p0q0);
+        *p1q1 = vbsl_u8(filter14_cond, out_f14_pq1, *p1q1);
+        *p2q2 = vbsl_u8(filter14_cond, out_f14_pq2, *p2q2);
+        *p3q3 = vbsl_u8(filter14_cond, out_f14_pq3, *p3q3);
+        *p4q4 = vbsl_u8(filter14_cond, out_f14_pq4, *p4q4);
+        *p5q5 = vbsl_u8(filter14_cond, out_f14_pq5, *p5q5);
+      }
+    }
+  }
+}
+
 static inline void lpf_8_neon(uint8x8_t *p3q3, uint8x8_t *p2q2, uint8x8_t *p1q1,
                               uint8x8_t *p0q0, const uint8_t blimit,
                               const uint8_t limit, const uint8_t thresh) {