Optimize Armv8.0 Neon SAD4D 16xh, 32xh, 64xh and 128xh functions

Add a widening 4D reduction function operating on uint16x8_t vectors
and use it to optimize the final reduction in Armv8.0 Neon standard
bitdepth 32xh SAD4D computations.
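
As a point of reference, the new helper is equivalent to the scalar
loop below (a sketch only; reduce_4d_model, lo and hi are illustrative
names, with lo[i]/hi[i] standing in for the 8 uint16_t lanes of the
i-th low/high accumulator):

  #include <stdint.h>

  // Scalar model of the widening 4D reduction: widen each 16-bit lane
  // to 32 bits, then sum all 16 lanes per candidate block.
  static void reduce_4d_model(const uint16_t lo[4][8],
                              const uint16_t hi[4][8], uint32_t res[4]) {
    for (int i = 0; i < 4; i++) {
      uint32_t acc = 0;
      for (int j = 0; j < 8; j++) {
        acc += (uint32_t)lo[i][j] + (uint32_t)hi[i][j];
      }
      res[i] = acc;
    }
  }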

In the 16xh, 64xh and 128xh SAD4D functions, use widening pairwise
addition and accumulation instructions to widen the per-block
uint16x8_t sums to uint32x4_t, before a final 4D reduction operating
on uint32x4_t vectors.
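
The final 4D reduction uses the existing horizontal_add_4d_u32x4()
helper from sum_neon.h, which is not part of this patch. A minimal
sketch of such a reduction, mirroring the AArch64/Armv7 split used by
the new horizontal_long_add_4d_u16x8() below, might look as follows;
it assumes <arm_neon.h> and aom's INLINE macro, and the in-tree
definition may differ in detail:

  static INLINE uint32x4_t horizontal_add_4d_u32x4(const uint32x4_t sum[4]) {
  #if defined(__aarch64__)
    // Two rounds of pairwise adds leave the full reduction of sum[i]
    // in lane i of the result.
    const uint32x4_t a0 = vpaddq_u32(sum[0], sum[1]);
    const uint32x4_t a1 = vpaddq_u32(sum[2], sum[3]);
    return vpaddq_u32(a0, a1);
  #else
    // Armv7 Neon has no vpaddq_u32: halve each vector first, then
    // pair up the 64-bit halves.
    const uint32x2_t a0 =
        vadd_u32(vget_low_u32(sum[0]), vget_high_u32(sum[0]));
    const uint32x2_t a1 =
        vadd_u32(vget_low_u32(sum[1]), vget_high_u32(sum[1]));
    const uint32x2_t a2 =
        vadd_u32(vget_low_u32(sum[2]), vget_high_u32(sum[2]));
    const uint32x2_t a3 =
        vadd_u32(vget_low_u32(sum[3]), vget_high_u32(sum[3]));
    return vcombine_u32(vpadd_u32(a0, a1), vpadd_u32(a2, a3));
  #endif
  }

Since lane i of the returned vector holds the complete reduction of
sum[i], the result can be stored straight to res[4] with a single
vst1q_u32(), as done at the end of each SAD4D function below.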

Change-Id: Ide120ecc71f706eef2a603958927edbcef929d4b
diff --git a/aom_dsp/arm/sad4d_neon.c b/aom_dsp/arm/sad4d_neon.c
index 81ec908..467f44c 100644
--- a/aom_dsp/arm/sad4d_neon.c
+++ b/aom_dsp/arm/sad4d_neon.c
@@ -207,7 +207,8 @@
 static INLINE void sad128xhx4d_neon(const uint8_t *src, int src_stride,
                                     const uint8_t *const ref[4], int ref_stride,
                                     uint32_t res[4], int h) {
-  vst1q_u32(res, vdupq_n_u32(0));
+  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0) };
   int h_tmp = h > 32 ? 32 : h;
 
   int i = 0;
@@ -269,19 +270,26 @@
       i++;
     } while (i < h_tmp);
 
-    res[0] += horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);
-    res[1] += horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
-    res[2] += horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
-    res[3] += horizontal_long_add_u16x8(sum_lo[3], sum_hi[3]);
+    sum[0] = vpadalq_u16(sum[0], sum_lo[0]);
+    sum[0] = vpadalq_u16(sum[0], sum_hi[0]);
+    sum[1] = vpadalq_u16(sum[1], sum_lo[1]);
+    sum[1] = vpadalq_u16(sum[1], sum_hi[1]);
+    sum[2] = vpadalq_u16(sum[2], sum_lo[2]);
+    sum[2] = vpadalq_u16(sum[2], sum_hi[2]);
+    sum[3] = vpadalq_u16(sum[3], sum_lo[3]);
+    sum[3] = vpadalq_u16(sum[3], sum_hi[3]);
 
     h_tmp += 32;
   } while (i < h);
+
+  vst1q_u32(res, horizontal_add_4d_u32x4(sum));
 }
 
 static INLINE void sad64xhx4d_neon(const uint8_t *src, int src_stride,
                                    const uint8_t *const ref[4], int ref_stride,
                                    uint32_t res[4], int h) {
-  vst1q_u32(res, vdupq_n_u32(0));
+  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0) };
   int h_tmp = h > 64 ? 64 : h;
 
   int i = 0;
@@ -319,13 +327,19 @@
       i++;
     } while (i < h_tmp);
 
-    res[0] += horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);
-    res[1] += horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
-    res[2] += horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
-    res[3] += horizontal_long_add_u16x8(sum_lo[3], sum_hi[3]);
+    sum[0] = vpadalq_u16(sum[0], sum_lo[0]);
+    sum[0] = vpadalq_u16(sum[0], sum_hi[0]);
+    sum[1] = vpadalq_u16(sum[1], sum_lo[1]);
+    sum[1] = vpadalq_u16(sum[1], sum_hi[1]);
+    sum[2] = vpadalq_u16(sum[2], sum_lo[2]);
+    sum[2] = vpadalq_u16(sum[2], sum_hi[2]);
+    sum[3] = vpadalq_u16(sum[3], sum_lo[3]);
+    sum[3] = vpadalq_u16(sum[3], sum_hi[3]);
 
     h_tmp += 64;
   } while (i < h);
+
+  vst1q_u32(res, horizontal_add_4d_u32x4(sum));
 }
 
 static INLINE void sad32xhx4d_neon(const uint8_t *src, int src_stride,
@@ -353,33 +367,33 @@
     i++;
   } while (i < h);
 
-  res[0] = horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);
-  res[1] = horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
-  res[2] = horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
-  res[3] = horizontal_long_add_u16x8(sum_lo[3], sum_hi[3]);
+  vst1q_u32(res, horizontal_long_add_4d_u16x8(sum_lo, sum_hi));
 }
 
 static INLINE void sad16xhx4d_neon(const uint8_t *src, int src_stride,
                                    const uint8_t *const ref[4], int ref_stride,
                                    uint32_t res[4], int h) {
-  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                        vdupq_n_u16(0) };
+  uint16x8_t sum_u16[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                            vdupq_n_u16(0) };
+  uint32x4_t sum_u32[4];
 
   int i = 0;
   do {
     const uint8x16_t s = vld1q_u8(src + i * src_stride);
-    sad16_neon(s, vld1q_u8(ref[0] + i * ref_stride), &sum[0]);
-    sad16_neon(s, vld1q_u8(ref[1] + i * ref_stride), &sum[1]);
-    sad16_neon(s, vld1q_u8(ref[2] + i * ref_stride), &sum[2]);
-    sad16_neon(s, vld1q_u8(ref[3] + i * ref_stride), &sum[3]);
+    sad16_neon(s, vld1q_u8(ref[0] + i * ref_stride), &sum_u16[0]);
+    sad16_neon(s, vld1q_u8(ref[1] + i * ref_stride), &sum_u16[1]);
+    sad16_neon(s, vld1q_u8(ref[2] + i * ref_stride), &sum_u16[2]);
+    sad16_neon(s, vld1q_u8(ref[3] + i * ref_stride), &sum_u16[3]);
 
     i++;
   } while (i < h);
 
-  res[0] = horizontal_add_u16x8(sum[0]);
-  res[1] = horizontal_add_u16x8(sum[1]);
-  res[2] = horizontal_add_u16x8(sum[2]);
-  res[3] = horizontal_add_u16x8(sum[3]);
+  sum_u32[0] = vpaddlq_u16(sum_u16[0]);
+  sum_u32[1] = vpaddlq_u16(sum_u16[1]);
+  sum_u32[2] = vpaddlq_u16(sum_u16[2]);
+  sum_u32[3] = vpaddlq_u16(sum_u16[3]);
+
+  vst1q_u32(res, horizontal_add_4d_u32x4(sum_u32));
 }
 
 #endif  // defined(__ARM_FEATURE_DOTPROD)
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index c0bfc69..dc0ea9f 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -97,6 +97,31 @@
 #endif
 }
 
+static INLINE uint32x4_t horizontal_long_add_4d_u16x8(
+    const uint16x8_t sum_lo[4], const uint16x8_t sum_hi[4]) {
+  const uint32x4_t a0 = vpaddlq_u16(sum_lo[0]);
+  const uint32x4_t a1 = vpaddlq_u16(sum_lo[1]);
+  const uint32x4_t a2 = vpaddlq_u16(sum_lo[2]);
+  const uint32x4_t a3 = vpaddlq_u16(sum_lo[3]);
+  const uint32x4_t b0 = vpadalq_u16(a0, sum_hi[0]);
+  const uint32x4_t b1 = vpadalq_u16(a1, sum_hi[1]);
+  const uint32x4_t b2 = vpadalq_u16(a2, sum_hi[2]);
+  const uint32x4_t b3 = vpadalq_u16(a3, sum_hi[3]);
+#if defined(__aarch64__)
+  const uint32x4_t c0 = vpaddq_u32(b0, b1);
+  const uint32x4_t c1 = vpaddq_u32(b2, b3);
+  return vpaddq_u32(c0, c1);
+#else
+  const uint32x2_t c0 = vadd_u32(vget_low_u32(b0), vget_high_u32(b0));
+  const uint32x2_t c1 = vadd_u32(vget_low_u32(b1), vget_high_u32(b1));
+  const uint32x2_t c2 = vadd_u32(vget_low_u32(b2), vget_high_u32(b2));
+  const uint32x2_t c3 = vadd_u32(vget_low_u32(b3), vget_high_u32(b3));
+  const uint32x2_t d0 = vpadd_u32(c0, c1);
+  const uint32x2_t d1 = vpadd_u32(c2, c3);
+  return vcombine_u32(d0, d1);
+#endif
+}
+
 static INLINE uint32_t horizontal_add_u16x8(const uint16x8_t a) {
 #if defined(__aarch64__)
   return vaddlvq_u16(a);