intrapred_neon: update aom_smooth_predictor_{16,32,64}xN_neon

Update to the version based on libgav1 @ v0.17.0-83-ge54abf5c.

~2.05-2.34x faster than the previous NEON implementation, depending on the
block size.
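
For reference (the notation below is illustrative, not the exact code), each
SMOOTH pixel is the sum of a vertical and a horizontal weighted average,
scaled by 256 and rounded:

  pred(x, y) = (w_y[y] * top[x]  + (256 - w_y[y]) * bottom_left +
                w_x[x] * left[y] + (256 - w_x[x]) * top_right + 256) >> 9

The new version keeps this in 8/16-bit lanes: the complementary weights
256 - w are obtained with a vector negate (identical modulo 256, and the
smooth weights are non-zero so the result fits in eight bits), the
bottom-left term is hoisted out once per row, and each row is produced
16 pixels at a time with vmull_u8/vmlal_u8.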

Bug: b/231616924
Change-Id: I9527244a9f6d6a2ecd678fd425f459512cbdb9a3
diff --git a/aom_dsp/arm/intrapred_neon.c b/aom_dsp/arm/intrapred_neon.c
index dfd56a3..bdc6a29 100644
--- a/aom_dsp/arm/intrapred_neon.c
+++ b/aom_dsp/arm/intrapred_neon.c
@@ -2753,130 +2753,135 @@
 
 #undef SMOOTH_NXM
 
-static INLINE void smooth_predictor_wxh(uint8_t *dst, ptrdiff_t stride,
-                                        const uint8_t *above,
-                                        const uint8_t *left, uint32_t bw,
-                                        uint32_t bh) {
-  const uint8_t *const sm_weights_w = smooth_weights + bw - 4;
-  const uint8_t *const sm_weights_h = smooth_weights + bh - 4;
-  const uint16x8_t scale_value = vdupq_n_u16(256);
+static INLINE uint8x16_t calculate_weights_and_predq(
+    const uint8x16_t top, const uint8x8_t left, const uint8x8_t top_right,
+    const uint8x8_t weights_y, const uint8x16_t weights_x,
+    const uint8x16_t scaled_weights_x, const uint16x8_t weighted_bl) {
+  const uint16x8_t weighted_top_bl_low =
+      vmlal_u8(weighted_bl, weights_y, vget_low_u8(top));
+  const uint16x8_t weighted_left_low = vmull_u8(vget_low_u8(weights_x), left);
+  const uint16x8_t weighted_left_tr_low =
+      vmlal_u8(weighted_left_low, vget_low_u8(scaled_weights_x), top_right);
+  const uint8x8_t result_low =
+      calculate_pred(weighted_top_bl_low, weighted_left_tr_low);
 
-  for (uint32_t y = 0; y < bh; ++y) {
-    const uint8x8_t left_y = vdup_n_u8(left[y]);
-    const uint8x8_t weights_y_dup = vdup_n_u8(sm_weights_h[y]);
-    const uint32x4_t pred_scaled_bl =
-        vdupq_n_u32(256 + (256 - sm_weights_h[y]) * left[bh - 1]);
+  const uint16x8_t weighted_top_bl_high =
+      vmlal_u8(weighted_bl, weights_y, vget_high_u8(top));
+  const uint16x8_t weighted_left_high = vmull_u8(vget_high_u8(weights_x), left);
+  const uint16x8_t weighted_left_tr_high =
+      vmlal_u8(weighted_left_high, vget_high_u8(scaled_weights_x), top_right);
+  const uint8x8_t result_high =
+      calculate_pred(weighted_top_bl_high, weighted_left_tr_high);
 
-    for (uint32_t x = 0; x < bw; x += 8) {
-      const uint8x8_t weights_x = vld1_u8(sm_weights_w + x);
-      const uint8x8_t top_x = vld1_u8(above + x);
+  return vcombine_u8(result_low, result_high);
+}
 
-      uint16x8_t pred_m1, pred_m2;
-      uint32x4_t pred_lo, pred_hi;
-      pred_m1 = vmull_u8(top_x, weights_y_dup);
-      pred_m2 = vmull_u8(weights_x, left_y);
+// 256 - v = vnegq_s8(v): for unsigned 8-bit lanes, negation and (256 - v)
+// agree modulo 256.
+static INLINE uint8x16_t negate_s8q(const uint8x16_t v) {
+  return vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(v)));
+}
 
-      pred_lo = vaddl_u16(vget_low_u16(pred_m1), vget_low_u16(pred_m2));
-#if defined(__aarch64__)
-      pred_hi = vaddl_high_u16(pred_m1, pred_m2);
-#else
-      pred_hi = vaddl_u16(vget_high_u16(pred_m1), vget_high_u16(pred_m2));
-#endif  // (__aarch64__)
-
-      const uint16x8_t scale_m_weights_x = vsubw_u8(scale_value, weights_x);
-
-      const uint16x8_t swxtr = vmulq_n_u16(scale_m_weights_x, above[bw - 1]);
-
-      pred_lo = vaddq_u32(pred_lo, pred_scaled_bl);
-      pred_hi = vaddq_u32(pred_hi, pred_scaled_bl);
-
-      pred_lo = vaddw_u16(pred_lo, vget_low_u16(swxtr));
-#if defined(__aarch64__)
-      pred_hi = vaddw_high_u16(pred_hi, swxtr);
-#else
-      pred_hi = vaddw_u16(pred_hi, vget_high_u16(swxtr));
-#endif  // (__aarch64__)
-
-      uint16x8_t pred =
-          vcombine_u16(vshrn_n_u32(pred_lo, 9), vshrn_n_u32(pred_hi, 9));
-
-      uint8x8_t predsh = vqmovn_u16(pred);
-
-      vst1_u8(dst + x, predsh);
-    }
-
-    dst += stride;
+// For width 16 and above.
+#define SMOOTH_PREDICTOR(W)                                                 \
+  static void smooth_##W##xh_neon(                                          \
+      uint8_t *dst, ptrdiff_t stride, const uint8_t *const top_row,         \
+      const uint8_t *const left_column, const int height) {                 \
+    const uint8_t top_right = top_row[(W)-1];                               \
+    const uint8_t bottom_left = left_column[height - 1];                    \
+    const uint8_t *const weights_y = smooth_weights + height - 4;           \
+                                                                            \
+    uint8x16_t top_v[4];                                                    \
+    top_v[0] = vld1q_u8(top_row);                                           \
+    if ((W) > 16) {                                                         \
+      top_v[1] = vld1q_u8(top_row + 16);                                    \
+      if ((W) == 64) {                                                      \
+        top_v[2] = vld1q_u8(top_row + 32);                                  \
+        top_v[3] = vld1q_u8(top_row + 48);                                  \
+      }                                                                     \
+    }                                                                       \
+                                                                            \
+    const uint8x8_t top_right_v = vdup_n_u8(top_right);                     \
+    const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left);                 \
+                                                                            \
+    uint8x16_t weights_x_v[4];                                              \
+    weights_x_v[0] = vld1q_u8(smooth_weights + (W)-4);                      \
+    if ((W) > 16) {                                                         \
+      weights_x_v[1] = vld1q_u8(smooth_weights + (W) + 16 - 4);             \
+      if ((W) == 64) {                                                      \
+        weights_x_v[2] = vld1q_u8(smooth_weights + (W) + 32 - 4);           \
+        weights_x_v[3] = vld1q_u8(smooth_weights + (W) + 48 - 4);           \
+      }                                                                     \
+    }                                                                       \
+                                                                            \
+    uint8x16_t scaled_weights_x[4];                                         \
+    scaled_weights_x[0] = negate_s8q(weights_x_v[0]);                       \
+    if ((W) > 16) {                                                         \
+      scaled_weights_x[1] = negate_s8q(weights_x_v[1]);                     \
+      if ((W) == 64) {                                                      \
+        scaled_weights_x[2] = negate_s8q(weights_x_v[2]);                   \
+        scaled_weights_x[3] = negate_s8q(weights_x_v[3]);                   \
+      }                                                                     \
+    }                                                                       \
+                                                                            \
+    for (int y = 0; y < height; ++y) {                                      \
+      const uint8x8_t left_v = vdup_n_u8(left_column[y]);                   \
+      const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]);                \
+      const uint8x8_t scaled_weights_y = negate_s8(weights_y_v);            \
+      const uint16x8_t weighted_bl =                                        \
+          vmull_u8(scaled_weights_y, bottom_left_v);                        \
+                                                                            \
+      vst1q_u8(dst, calculate_weights_and_predq(                            \
+                        top_v[0], left_v, top_right_v, weights_y_v,         \
+                        weights_x_v[0], scaled_weights_x[0], weighted_bl)); \
+                                                                            \
+      if ((W) > 16) {                                                       \
+        vst1q_u8(dst + 16,                                                  \
+                 calculate_weights_and_predq(                               \
+                     top_v[1], left_v, top_right_v, weights_y_v,            \
+                     weights_x_v[1], scaled_weights_x[1], weighted_bl));    \
+        if ((W) == 64) {                                                    \
+          vst1q_u8(dst + 32,                                                \
+                   calculate_weights_and_predq(                             \
+                       top_v[2], left_v, top_right_v, weights_y_v,          \
+                       weights_x_v[2], scaled_weights_x[2], weighted_bl));  \
+          vst1q_u8(dst + 48,                                                \
+                   calculate_weights_and_predq(                             \
+                       top_v[3], left_v, top_right_v, weights_y_v,          \
+                       weights_x_v[3], scaled_weights_x[3], weighted_bl));  \
+        }                                                                   \
+      }                                                                     \
+                                                                            \
+      dst += stride;                                                        \
+    }                                                                       \
   }
-}
 
-void aom_smooth_predictor_16x4_neon(uint8_t *dst, ptrdiff_t stride,
-                                    const uint8_t *above, const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 16, 4);
-}
+SMOOTH_PREDICTOR(16)
+SMOOTH_PREDICTOR(32)
+SMOOTH_PREDICTOR(64)
 
-void aom_smooth_predictor_16x8_neon(uint8_t *dst, ptrdiff_t stride,
-                                    const uint8_t *above, const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 16, 8);
-}
+#undef SMOOTH_PREDICTOR
 
-void aom_smooth_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
-                                     const uint8_t *above,
-                                     const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 16, 16);
-}
+#define SMOOTH_NXM_WIDE(W, H)                                                  \
+  void aom_smooth_predictor_##W##x##H##_neon(uint8_t *dst, ptrdiff_t y_stride, \
+                                             const uint8_t *above,             \
+                                             const uint8_t *left) {            \
+    smooth_##W##xh_neon(dst, y_stride, above, left, H);                        \
+  }
 
-void aom_smooth_predictor_16x32_neon(uint8_t *dst, ptrdiff_t stride,
-                                     const uint8_t *above,
-                                     const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 16, 32);
-}
+SMOOTH_NXM_WIDE(16, 4)
+SMOOTH_NXM_WIDE(16, 8)
+SMOOTH_NXM_WIDE(16, 16)
+SMOOTH_NXM_WIDE(16, 32)
+SMOOTH_NXM_WIDE(16, 64)
+SMOOTH_NXM_WIDE(32, 8)
+SMOOTH_NXM_WIDE(32, 16)
+SMOOTH_NXM_WIDE(32, 32)
+SMOOTH_NXM_WIDE(32, 64)
+SMOOTH_NXM_WIDE(64, 16)
+SMOOTH_NXM_WIDE(64, 32)
+SMOOTH_NXM_WIDE(64, 64)
 
-void aom_smooth_predictor_32x8_neon(uint8_t *dst, ptrdiff_t stride,
-                                    const uint8_t *above, const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 32, 8);
-}
-
-void aom_smooth_predictor_32x16_neon(uint8_t *dst, ptrdiff_t stride,
-                                     const uint8_t *above,
-                                     const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 32, 16);
-}
-
-void aom_smooth_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
-                                     const uint8_t *above,
-                                     const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 32, 32);
-}
-
-void aom_smooth_predictor_32x64_neon(uint8_t *dst, ptrdiff_t stride,
-                                     const uint8_t *above,
-                                     const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 32, 64);
-}
-
-void aom_smooth_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride,
-                                     const uint8_t *above,
-                                     const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 64, 64);
-}
-
-void aom_smooth_predictor_64x32_neon(uint8_t *dst, ptrdiff_t stride,
-                                     const uint8_t *above,
-                                     const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 64, 32);
-}
-
-void aom_smooth_predictor_64x16_neon(uint8_t *dst, ptrdiff_t stride,
-                                     const uint8_t *above,
-                                     const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 64, 16);
-}
-
-void aom_smooth_predictor_16x64_neon(uint8_t *dst, ptrdiff_t stride,
-                                     const uint8_t *above,
-                                     const uint8_t *left) {
-  smooth_predictor_wxh(dst, stride, above, left, 16, 64);
-}
+#undef SMOOTH_NXM_WIDE
 
 // -----------------------------------------------------------------------------
 // PAETH