intrapred_neon: update aom_smooth_predictor_8xN_neon

to the version based on libgav1 @ v0.17.0-83-ge54abf5c

~1.44-1.72x faster depending on the block size

Bug: b/231616924
Change-Id: I0ec9bfbdb41d8984594edcd5427542df080c23a9
diff --git a/aom_dsp/arm/intrapred_neon.c b/aom_dsp/arm/intrapred_neon.c
index 8f4e352..dfd56a3 100644
--- a/aom_dsp/arm/intrapred_neon.c
+++ b/aom_dsp/arm/intrapred_neon.c
@@ -2690,6 +2690,52 @@
   }
 }
 
+static INLINE uint8x8_t calculate_pred(const uint16x8_t weighted_top_bl,
+                                       const uint16x8_t weighted_left_tr) {  // Combine the vertical and horizontal weighted sums into 8-bit pixels.
+  // Maximum value of each parameter: 0xFF00
+  const uint16x8_t avg = vhaddq_u16(weighted_top_bl, weighted_left_tr);  // halving add keeps the sum within 16 bits
+  return vrshrn_n_u16(avg, SMOOTH_WEIGHT_LOG2_SCALE);  // rounding narrow; with the halving add, total shift is SMOOTH_WEIGHT_LOG2_SCALE + 1
+}
+
+static INLINE uint8x8_t calculate_weights_and_pred(
+    const uint8x8_t top, const uint8x8_t left, const uint16x8_t weighted_tr,
+    const uint8x8_t bottom_left, const uint8x8_t weights_x,
+    const uint8x8_t scaled_weights_y, const uint8x8_t weights_y) {  // Blend top/bottom_left (vertical) and left/top_right (horizontal), then average.
+  const uint16x8_t weighted_top = vmull_u8(weights_y, top);  // weights_y * top
+  const uint16x8_t weighted_top_bl =
+      vmlal_u8(weighted_top, scaled_weights_y, bottom_left);  // + scaled_weights_y * bottom_left: vertical blend
+  const uint16x8_t weighted_left_tr = vmlal_u8(weighted_tr, weights_x, left);  // weighted_tr + weights_x * left: horizontal blend
+  return calculate_pred(weighted_top_bl, weighted_left_tr);
+}
+
+static void smooth_8xh_neon(uint8_t *dst, ptrdiff_t stride,
+                            const uint8_t *const top_row,
+                            const uint8_t *const left_column,
+                            const int height) {  // SMOOTH predictor for 8-wide blocks; height is 4/8/16/32.
+  const uint8_t top_right = top_row[7];  // top_row[width - 1] with width == 8
+  const uint8_t bottom_left = left_column[height - 1];
+  const uint8_t *const weights_y = smooth_weights + height - 4;  // per-row weights; table entries for size N start at offset N - 4
+
+  const uint8x8_t top_v = vld1_u8(top_row);
+  const uint8x8_t top_right_v = vdup_n_u8(top_right);
+  const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left);
+  const uint8x8_t weights_x_v = vld1_u8(smooth_weights + 4);  // width-8 weights: offset 8 - 4
+  const uint8x8_t scaled_weights_x = negate_s8(weights_x_v);  // scale (256) - weights_x via modulo-256 negation; negate_s8 defined elsewhere
+  const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v);  // row-invariant horizontal term, hoisted out of the loop
+
+  for (int y = 0; y < height; ++y) {
+    const uint8x8_t left_v = vdup_n_u8(left_column[y]);
+    const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]);
+    const uint8x8_t scaled_weights_y = negate_s8(weights_y_v);  // scale (256) - weights_y[y]
+    const uint8x8_t result =
+        calculate_weights_and_pred(top_v, left_v, weighted_tr, bottom_left_v,
+                                   weights_x_v, scaled_weights_y, weights_y_v);
+
+    vst1_u8(dst, result);  // one 8-pixel row per iteration
+    dst += stride;
+  }
+}
+
 #define SMOOTH_NXM(W, H)                                                       \
   void aom_smooth_predictor_##W##x##H##_neon(uint8_t *dst, ptrdiff_t y_stride, \
                                              const uint8_t *above,             \
@@ -2699,220 +2745,14 @@
 
 SMOOTH_NXM(4, 4)
 SMOOTH_NXM(4, 8)
+SMOOTH_NXM(8, 4)
+SMOOTH_NXM(8, 8)
 SMOOTH_NXM(4, 16)
+SMOOTH_NXM(8, 16)
+SMOOTH_NXM(8, 32)
 
 #undef SMOOTH_NXM
 
-// pixels[0]: above and below_pred interleave vector, first half
-// pixels[1]: above and below_pred interleave vector, second half
-// pixels[2]: left vector
-// pixels[3]: right_pred vector
-// pixels[4]: above and below_pred interleave vector, first half
-// pixels[5]: above and below_pred interleave vector, second half
-// pixels[6]: left vector + 16
-// pixels[7]: right_pred vector
-static INLINE void load_pixel_w8(const uint8_t *above, const uint8_t *left,
-                                 int height, uint8x16_t *pixels) {
-  pixels[0] = vreinterpretq_u8_u16(vmovl_u8(vld1_u8(above)));
-  pixels[1] = vreinterpretq_u8_u16(vdupq_n_u16((uint16_t)left[height - 1]));
-  pixels[3] = vreinterpretq_u8_u16(vdupq_n_u16((uint16_t)above[7]));
-
-  if (height == 4) {
-    const uint32x4_t zero32 = vdupq_n_u32(0);
-    pixels[2] =
-        vreinterpretq_u8_u32(vld1q_lane_u32((const uint32_t *)left, zero32, 0));
-  } else if (height == 8) {
-    const uint64x2_t zero64 = vdupq_n_u64(0);
-    pixels[2] = vreinterpretq_u8_u64(
-        vsetq_lane_u64(((const uint64_t *)left)[0], zero64, 0));
-  } else if (height == 16) {
-    pixels[2] = vld1q_u8(left);
-  } else {
-    pixels[2] = vld1q_u8(left);
-    pixels[4] = pixels[0];
-    pixels[5] = pixels[1];
-    pixels[6] = vld1q_u8(left + 16);
-    pixels[7] = pixels[3];
-  }
-}
-
-// weight_h[0]: weight_h vector
-// weight_h[1]: scale - weight_h vector
-// weight_h[2]: same as [0], offset 8
-// weight_h[3]: same as [1], offset 8
-// weight_h[4]: same as [0], offset 16
-// weight_h[5]: same as [1], offset 16
-// weight_h[6]: same as [0], offset 24
-// weight_h[7]: same as [1], offset 24
-// weight_w[0]: weights_w and scale - weights_w interleave vector, first half
-// weight_w[1]: weights_w and scale - weights_w interleave vector, second half
-static INLINE void load_weight_w8(int height, uint16x8_t *weight_h,
-                                  uint16x8_t *weight_w) {
-  const uint8x16_t zero = vdupq_n_u8(0);
-  const int we_offset = height < 8 ? 0 : 4;
-  uint8x16_t we = vld1q_u8(&smooth_weights[we_offset]);
-#if defined(__aarch64__)
-  weight_h[0] = vreinterpretq_u16_u8(vzip1q_u8(we, zero));
-#else
-  weight_h[0] = vreinterpretq_u16_u8(vzipq_u8(we, zero).val[0]);
-#endif  // (__aarch64__)
-  const uint16x8_t d = vdupq_n_u16(256);
-  weight_h[1] = vsubq_u16(d, weight_h[0]);
-
-  if (height == 4) {
-    we = vextq_u8(we, zero, 4);
-#if defined(__aarch64__)
-    weight_w[0] = vreinterpretq_u16_u8(vzip1q_u8(we, zero));
-#else
-    weight_w[0] = vmovl_u8(vget_low_u8(we));
-#endif  // (__aarch64__)
-    weight_w[1] = vsubq_u16(d, weight_w[0]);
-  } else {
-    weight_w[0] = weight_h[0];
-    weight_w[1] = weight_h[1];
-  }
-
-  if (height == 16) {
-    we = vld1q_u8(&smooth_weights[12]);
-    const uint8x16x2_t weight_h_02 = vzipq_u8(we, zero);
-    weight_h[0] = vreinterpretq_u16_u8(weight_h_02.val[0]);
-    weight_h[1] = vsubq_u16(d, weight_h[0]);
-    weight_h[2] = vreinterpretq_u16_u8(weight_h_02.val[1]);
-    weight_h[3] = vsubq_u16(d, weight_h[2]);
-  } else if (height == 32) {
-    const uint8x16_t weight_lo = vld1q_u8(&smooth_weights[28]);
-    const uint8x16x2_t weight_h_02 = vzipq_u8(weight_lo, zero);
-    weight_h[0] = vreinterpretq_u16_u8(weight_h_02.val[0]);
-    weight_h[1] = vsubq_u16(d, weight_h[0]);
-    weight_h[2] = vreinterpretq_u16_u8(weight_h_02.val[1]);
-    weight_h[3] = vsubq_u16(d, weight_h[2]);
-    const uint8x16_t weight_hi = vld1q_u8(&smooth_weights[28 + 16]);
-    const uint8x16x2_t weight_h_46 = vzipq_u8(weight_hi, zero);
-    weight_h[4] = vreinterpretq_u16_u8(weight_h_46.val[0]);
-    weight_h[5] = vsubq_u16(d, weight_h[4]);
-    weight_h[6] = vreinterpretq_u16_u8(weight_h_46.val[1]);
-    weight_h[7] = vsubq_u16(d, weight_h[6]);
-  }
-}
-
-static INLINE void smooth_pred_8xh(const uint8x16_t *pixels,
-                                   const uint16x8_t *wh, const uint16x8_t *ww,
-                                   int h, uint8_t *dst, ptrdiff_t stride,
-                                   int second_half) {
-  const uint16x8_t one = vdupq_n_u16(1);
-  const uint16x8_t inc = vdupq_n_u16(0x202);
-  uint16x8_t rep = second_half ? vdupq_n_u16((uint16_t)0x8008)
-                               : vdupq_n_u16((uint16_t)0x8000);
-  uint16x8_t d = vdupq_n_u16(0x100);
-
-#if !defined(__aarch64__)
-  const uint8x8x2_t v_split1 = { { vget_low_u8(vreinterpretq_u8_u16(wh[0])),
-                                   vget_high_u8(
-                                       vreinterpretq_u8_u16(wh[0])) } };
-  const uint8x8x2_t v_split2 = { { vget_low_u8(vreinterpretq_u8_u16(wh[1])),
-                                   vget_high_u8(
-                                       vreinterpretq_u8_u16(wh[1])) } };
-  const uint8x8x2_t v_split3 = { { vget_low_u8(pixels[2]),
-                                   vget_high_u8(pixels[2]) } };
-#endif
-
-  for (int i = 0; i < h; ++i) {
-#if defined(__aarch64__)
-    const uint8x16_t wg_wg =
-        vqtbl1q_u8(vreinterpretq_u8_u16(wh[0]), vreinterpretq_u8_u16(d));
-    const uint8x16_t sc_sc =
-        vqtbl1q_u8(vreinterpretq_u8_u16(wh[1]), vreinterpretq_u8_u16(d));
-#else
-    const uint8x8_t v_d_lo = vreinterpret_u8_u16(vget_low_u16(d));
-    const uint8x8_t v_d_hi = vreinterpret_u8_u16(vget_high_u16(d));
-    const uint8x16_t wg_wg =
-        vcombine_u8(vtbl2_u8(v_split1, v_d_lo), vtbl2_u8(v_split1, v_d_hi));
-    const uint8x16_t sc_sc =
-        vcombine_u8(vtbl2_u8(v_split2, v_d_lo), vtbl2_u8(v_split2, v_d_hi));
-#endif  // (__aarch64__)
-    uint16x8_t s01 =
-        vmulq_u16(vreinterpretq_u16_u8(pixels[0]), vreinterpretq_u16_u8(wg_wg));
-    s01 = vmlaq_u16(s01, vreinterpretq_u16_u8(pixels[1]),
-                    vreinterpretq_u16_u8(sc_sc));
-#if defined(__aarch64__)
-    const uint8x16_t b = vqtbl1q_u8(pixels[2], vreinterpretq_u8_u16(rep));
-#else
-    const uint8x16_t b = vcombine_u8(
-        vtbl2_u8(v_split3, vget_low_u8(vreinterpretq_u8_u16(rep))),
-        vtbl2_u8(v_split3, vget_high_u8(vreinterpretq_u8_u16(rep))));
-#endif  // (__aarch64__)
-    uint16x8_t sum0 = vmulq_u16(vreinterpretq_u16_u8(b), ww[0]);
-    sum0 = vmlaq_u16(sum0, vreinterpretq_u16_u8(pixels[3]), ww[1]);
-
-    uint32x4_t s0 = vaddl_u16(vget_low_u16(s01), vget_low_u16(sum0));
-#if defined(__aarch64__)
-    uint32x4_t s1 = vaddl_high_u16(s01, sum0);
-#else
-    uint32x4_t s1 = vaddl_u16(vget_high_u16(s01), vget_high_u16(sum0));
-#endif  // (__aarch64__)
-
-    sum0 = vcombine_u16(vqrshrn_n_u32(s0, 9), vqrshrn_n_u32(s1, 9));
-    uint8x8_t predsh = vqmovn_u16(sum0);
-    vst1_u8(dst, predsh);
-
-    dst += stride;
-    rep = vaddq_u16(rep, one);
-    d = vaddq_u16(d, inc);
-  }
-}
-
-void aom_smooth_predictor_8x4_neon(uint8_t *dst, ptrdiff_t stride,
-                                   const uint8_t *above, const uint8_t *left) {
-  uint8x16_t pixels[4];
-  load_pixel_w8(above, left, 4, pixels);
-
-  uint16x8_t wh[4], ww[2];
-  load_weight_w8(4, wh, ww);
-
-  smooth_pred_8xh(pixels, wh, ww, 4, dst, stride, 0);
-}
-
-void aom_smooth_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
-                                   const uint8_t *above, const uint8_t *left) {
-  uint8x16_t pixels[4];
-  load_pixel_w8(above, left, 8, pixels);
-
-  uint16x8_t wh[4], ww[2];
-  load_weight_w8(8, wh, ww);
-
-  smooth_pred_8xh(pixels, wh, ww, 8, dst, stride, 0);
-}
-
-void aom_smooth_predictor_8x16_neon(uint8_t *dst, ptrdiff_t stride,
-                                    const uint8_t *above, const uint8_t *left) {
-  uint8x16_t pixels[4];
-  load_pixel_w8(above, left, 16, pixels);
-
-  uint16x8_t wh[4], ww[2];
-  load_weight_w8(16, wh, ww);
-
-  smooth_pred_8xh(pixels, wh, ww, 8, dst, stride, 0);
-  dst += stride << 3;
-  smooth_pred_8xh(pixels, &wh[2], ww, 8, dst, stride, 1);
-}
-
-void aom_smooth_predictor_8x32_neon(uint8_t *dst, ptrdiff_t stride,
-                                    const uint8_t *above, const uint8_t *left) {
-  uint8x16_t pixels[8];
-  load_pixel_w8(above, left, 32, pixels);
-
-  uint16x8_t wh[8], ww[2];
-  load_weight_w8(32, wh, ww);
-
-  smooth_pred_8xh(&pixels[0], wh, ww, 8, dst, stride, 0);
-  dst += stride << 3;
-  smooth_pred_8xh(&pixels[0], &wh[2], ww, 8, dst, stride, 1);
-  dst += stride << 3;
-  smooth_pred_8xh(&pixels[4], &wh[4], ww, 8, dst, stride, 0);
-  dst += stride << 3;
-  smooth_pred_8xh(&pixels[4], &wh[6], ww, 8, dst, stride, 1);
-}
-
 static INLINE void smooth_predictor_wxh(uint8_t *dst, ptrdiff_t stride,
                                         const uint8_t *above,
                                         const uint8_t *left, uint32_t bw,