Optimize av1_dr_prediction_z2 Neon implementation

Optimize av1_dr_prediction_z2_neon by simplifying the shift and base
index calculations, and by folding the +16 rounding bias into rounding
narrowing shifts (vrshrn).

This optimization is a port from SVT-AV1.
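
The z2 predictor computes the two-tap interpolation
  above[x] * (32 - shift) + above[x + 1] * shift
as (above[x + 1] - above[x]) * shift + above[x] * 32, with the former
+16 rounding bias folded into a rounding narrowing shift
(vrshrn_n_u16). As an illustrative scalar sketch (not part of the
patch; the helper name is made up):

  #include <stdint.h>

  static uint8_t interp_pixel(uint8_t a0, uint8_t a1, int shift) {
    // shift is the 5-bit fractional position, 0..31.
    //   a0 * (32 - shift) + a1 * shift
    // = (a1 - a0) * shift + a0 * 32
    const int acc = (a1 - a0) * shift + a0 * 32;
    // Rounded narrowing, cf. vrshrn_n_u16(res, 5).
    return (uint8_t)((acc + 16) >> 5);
  }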

Change-Id: I38464487ff0fe4aeee2575ca9851944fc133b4e8
diff --git a/aom_dsp/arm/intrapred_neon.c b/aom_dsp/arm/intrapred_neon.c
index 6ea58b9..be63638 100644
--- a/aom_dsp/arm/intrapred_neon.c
+++ b/aom_dsp/arm/intrapred_neon.c
@@ -1497,19 +1497,19 @@
 static AOM_FORCE_INLINE void dr_prediction_z2_4xH_above_neon(
     const uint8_t *above, int upsample_above, int dx, int base_x, int y,
     uint8x8_t *a0_x, uint8x8_t *a1_x, uint16x4_t *shift0) {
-  uint16x4_t r6 = vcreate_u16(0x00C0008000400000);
-  uint16x4_t ydx = vdup_n_u16(y * dx);
+  const int x = -y * dx;
+
   if (upsample_above) {
     // Cannot use LD2 here since we only want to load eight bytes, but LD2 can
     // only load either 16 or 32.
     uint8x8_t v_tmp = vld1_u8(above + base_x);
     *a0_x = vuzp_u8(v_tmp, vdup_n_u8(0)).val[0];
     *a1_x = vuzp_u8(v_tmp, vdup_n_u8(0)).val[1];
-    *shift0 = vand_u16(vsub_u16(r6, ydx), vdup_n_u16(0x1f));
+    *shift0 = vdup_n_u16(x & 0x1f);  // ((x << upsample_above) & 0x3f) >> 1
   } else {
     *a0_x = load_unaligned_u8_4x1(above + base_x);
     *a1_x = load_unaligned_u8_4x1(above + base_x + 1);
-    *shift0 = vand_u16(vhsub_u16(r6, ydx), vdup_n_u16(0x1f));
+    *shift0 = vdup_n_u16((x & 0x3f) >> 1);
   }
 }
 
@@ -1519,19 +1519,15 @@
 #else
     const uint8_t *left,
 #endif
-    int upsample_left, int dy, int r, int min_base_y, int frac_bits_y,
-    uint16x4_t *a0_y, uint16x4_t *a1_y, uint16x4_t *shift1) {
+    int upsample_left, int dy, int r, int frac_bits_y, uint16x4_t *a0_y,
+    uint16x4_t *a1_y, uint16x4_t *shift1) {
   int16x4_t dy64 = vdup_n_s16(dy);
   int16x4_t v_1234 = vcreate_s16(0x0004000300020001);
   int16x4_t v_frac_bits_y = vdup_n_s16(-frac_bits_y);
-  int16x4_t min_base_y64 = vdup_n_s16(min_base_y);
   int16x4_t v_r6 = vdup_n_s16(r << 6);
   int16x4_t y_c64 = vmls_s16(v_r6, v_1234, dy64);
   int16x4_t base_y_c64 = vshl_s16(y_c64, v_frac_bits_y);
 
-  // Values in base_y_c64 range from -2 through 14 inclusive.
-  base_y_c64 = vmax_s16(base_y_c64, min_base_y64);
-
 #if AOM_ARCH_AARCH64
   uint8x8_t left_idx0 =
       vreinterpret_u8_s16(vadd_s16(base_y_c64, vdup_n_s16(2)));  // [0, 16]
@@ -1563,7 +1559,8 @@
 #endif  // AOM_ARCH_AARCH64
 
   if (upsample_left) {
-    *shift1 = vand_u16(vreinterpret_u16_s16(y_c64), vdup_n_u16(0x1f));
+    *shift1 = vand_u16(vreinterpret_u16_s16(y_c64),
+                       vdup_n_u16(0x1f));  // ((y << upsample_left) & 0x3f) >> 1
   } else {
     *shift1 =
         vand_u16(vshr_n_u16(vreinterpret_u16_s16(y_c64), 1), vdup_n_u16(0x1f));
@@ -1572,10 +1569,7 @@
 
 static AOM_FORCE_INLINE uint8x8_t dr_prediction_z2_8xH_above_neon(
     const uint8_t *above, int upsample_above, int dx, int base_x, int y) {
-  uint16x8_t c1234 = vcombine_u16(vcreate_u16(0x0004000300020001),
-                                  vcreate_u16(0x0008000700060005));
-  uint16x8_t ydx = vdupq_n_u16(y * dx);
-  uint16x8_t r6 = vshlq_n_u16(vextq_u16(c1234, vdupq_n_u16(0), 2), 6);
+  const int x = -y * dx;
 
   uint16x8_t shift0;
   uint8x8_t a0_x0;
@@ -1584,18 +1578,18 @@
     uint8x8x2_t v_tmp = vld2_u8(above + base_x);
     a0_x0 = v_tmp.val[0];
     a1_x0 = v_tmp.val[1];
-    shift0 = vandq_u16(vsubq_u16(r6, ydx), vdupq_n_u16(0x1f));
+    shift0 = vdupq_n_u16(x & 0x1f);  // ((x << upsample_above) & 0x3f) >> 1
   } else {
     a0_x0 = vld1_u8(above + base_x);
     a1_x0 = vld1_u8(above + base_x + 1);
-    shift0 = vandq_u16(vhsubq_u16(r6, ydx), vdupq_n_u16(0x1f));
+    shift0 = vdupq_n_u16((x & 0x3f) >> 1);
   }
 
   uint16x8_t diff0 = vsubl_u8(a1_x0, a0_x0);  // a[x+1] - a[x]
-  uint16x8_t a32 =
-      vmlal_u8(vdupq_n_u16(16), a0_x0, vdup_n_u8(32));  // a[x] * 32 + 16
+  uint16x8_t a32 = vshll_n_u8(a0_x0, 5);      // a[x] * 32
   uint16x8_t res = vmlaq_u16(a32, diff0, shift0);
-  return vshrn_n_u16(res, 5);
+
+  return vrshrn_n_u16(res, 5);
 }
 
 static AOM_FORCE_INLINE uint8x8_t dr_prediction_z2_8xH_left_neon(
@@ -1604,20 +1598,16 @@
 #else
     const uint8_t *left,
 #endif
-    int upsample_left, int dy, int r, int min_base_y, int frac_bits_y) {
+    int upsample_left, int dy, int r, int frac_bits_y) {
   int16x8_t v_r6 = vdupq_n_s16(r << 6);
   int16x8_t dy128 = vdupq_n_s16(dy);
   int16x8_t v_frac_bits_y = vdupq_n_s16(-frac_bits_y);
-  int16x8_t min_base_y128 = vdupq_n_s16(min_base_y);
 
   uint16x8_t c1234 = vcombine_u16(vcreate_u16(0x0004000300020001),
                                   vcreate_u16(0x0008000700060005));
   int16x8_t y_c128 = vmlsq_s16(v_r6, vreinterpretq_s16_u16(c1234), dy128);
   int16x8_t base_y_c128 = vshlq_s16(y_c128, v_frac_bits_y);
 
-  // Values in base_y_c128 range from -2 through 31 inclusive.
-  base_y_c128 = vmaxq_s16(base_y_c128, min_base_y128);
-
 #if AOM_ARCH_AARCH64
   uint8x16_t left_idx0 =
       vreinterpretq_u8_s16(vaddq_s16(base_y_c128, vdupq_n_s16(2)));  // [0, 33]
@@ -1635,46 +1625,39 @@
 
   uint16x8_t shift1;
   if (upsample_left) {
-    shift1 = vandq_u16(vreinterpretq_u16_s16(y_c128), vdupq_n_u16(0x1f));
+    shift1 =
+        vandq_u16(vreinterpretq_u16_s16(y_c128),
+                  vdupq_n_u16(0x1f));  // ((y << upsample_left) & 0x3f) >> 1
   } else {
     shift1 = vshrq_n_u16(
         vandq_u16(vreinterpretq_u16_s16(y_c128), vdupq_n_u16(0x3f)), 1);
   }
 
   uint16x8_t diff1 = vsubl_u8(a1_x1, a0_x1);
-  uint16x8_t a32 = vmlal_u8(vdupq_n_u16(16), a0_x1, vdup_n_u8(32));
+  uint16x8_t a32 = vshll_n_u8(a0_x1, 5);
   uint16x8_t res = vmlaq_u16(a32, diff1, shift1);
-  return vshrn_n_u16(res, 5);
+
+  return vrshrn_n_u16(res, 5);
 }
 
 static AOM_FORCE_INLINE uint8x16_t dr_prediction_z2_WxH_above_neon(
     const uint8_t *above, int dx, int base_x, int y, int j) {
-  uint16x8x2_t c0123 = { { vcombine_u16(vcreate_u16(0x0003000200010000),
-                                        vcreate_u16(0x0007000600050004)),
-                           vcombine_u16(vcreate_u16(0x000B000A00090008),
-                                        vcreate_u16(0x000F000E000D000C)) } };
-  uint16x8_t j256 = vdupq_n_u16(j);
-  uint16x8_t ydx = vdupq_n_u16((uint16_t)(y * dx));
-
+  const int x = -y * dx;
   const uint8x16_t a0_x128 = vld1q_u8(above + base_x + j);
   const uint8x16_t a1_x128 = vld1q_u8(above + base_x + j + 1);
-  uint16x8_t res6_0 = vshlq_n_u16(vaddq_u16(c0123.val[0], j256), 6);
-  uint16x8_t res6_1 = vshlq_n_u16(vaddq_u16(c0123.val[1], j256), 6);
-  uint16x8_t shift0 =
-      vshrq_n_u16(vandq_u16(vsubq_u16(res6_0, ydx), vdupq_n_u16(0x3f)), 1);
-  uint16x8_t shift1 =
-      vshrq_n_u16(vandq_u16(vsubq_u16(res6_1, ydx), vdupq_n_u16(0x3f)), 1);
+  const uint16x8_t shift = vdupq_n_u16((x & 0x3f) >> 1);
+
   // a[x+1] - a[x]
   uint16x8_t diff0 = vsubl_u8(vget_low_u8(a1_x128), vget_low_u8(a0_x128));
   uint16x8_t diff1 = vsubl_u8(vget_high_u8(a1_x128), vget_high_u8(a0_x128));
-  // a[x] * 32 + 16
-  uint16x8_t a32_0 =
-      vmlal_u8(vdupq_n_u16(16), vget_low_u8(a0_x128), vdup_n_u8(32));
-  uint16x8_t a32_1 =
-      vmlal_u8(vdupq_n_u16(16), vget_high_u8(a0_x128), vdup_n_u8(32));
-  uint16x8_t res0 = vmlaq_u16(a32_0, diff0, shift0);
-  uint16x8_t res1 = vmlaq_u16(a32_1, diff1, shift1);
-  return vcombine_u8(vshrn_n_u16(res0, 5), vshrn_n_u16(res1, 5));
+
+  // a[x] * 32
+  uint16x8_t a32_0 = vshll_n_u8(vget_low_u8(a0_x128), 5);
+  uint16x8_t a32_1 = vshll_n_u8(vget_high_u8(a0_x128), 5);
+  uint16x8_t res0 = vmlaq_u16(a32_0, diff0, shift);
+  uint16x8_t res1 = vmlaq_u16(a32_1, diff1, shift);
+
+  return vcombine_u8(vrshrn_n_u16(res0, 5), vrshrn_n_u16(res1, 5));
 }
 
 static AOM_FORCE_INLINE uint8x16_t dr_prediction_z2_WxH_left_neon(
@@ -1686,39 +1669,25 @@
     int dy, int r, int j) {
   // here upsample_above and upsample_left are 0 by design of
   // av1_use_intra_edge_upsample
-  const int min_base_y = -1;
-
-  int16x8_t min_base_y256 = vdupq_n_s16(min_base_y);
-  int16x8_t half_min_base_y256 = vdupq_n_s16(min_base_y >> 1);
   int16x8_t dy256 = vdupq_n_s16(dy);
   uint16x8_t j256 = vdupq_n_u16(j);
 
-  uint16x8x2_t c0123 = { { vcombine_u16(vcreate_u16(0x0003000200010000),
-                                        vcreate_u16(0x0007000600050004)),
-                           vcombine_u16(vcreate_u16(0x000B000A00090008),
-                                        vcreate_u16(0x000F000E000D000C)) } };
-  uint16x8x2_t c1234 = { { vaddq_u16(c0123.val[0], vdupq_n_u16(1)),
-                           vaddq_u16(c0123.val[1], vdupq_n_u16(1)) } };
+  uint16x8x2_t c1234 = { { vcombine_u16(vcreate_u16(0x0004000300020001),
+                                        vcreate_u16(0x0008000700060005)),
+                           vcombine_u16(vcreate_u16(0x000C000B000A0009),
+                                        vcreate_u16(0x0010000F000E000D)) } };
 
   int16x8_t v_r6 = vdupq_n_s16(r << 6);
 
   int16x8_t c256_0 = vreinterpretq_s16_u16(vaddq_u16(j256, c1234.val[0]));
   int16x8_t c256_1 = vreinterpretq_s16_u16(vaddq_u16(j256, c1234.val[1]));
-  int16x8_t mul16_lo = vreinterpretq_s16_u16(
-      vminq_u16(vreinterpretq_u16_s16(vmulq_s16(c256_0, dy256)),
-                vreinterpretq_u16_s16(half_min_base_y256)));
-  int16x8_t mul16_hi = vreinterpretq_s16_u16(
-      vminq_u16(vreinterpretq_u16_s16(vmulq_s16(c256_1, dy256)),
-                vreinterpretq_u16_s16(half_min_base_y256)));
-  int16x8_t y_c256_lo = vsubq_s16(v_r6, mul16_lo);
-  int16x8_t y_c256_hi = vsubq_s16(v_r6, mul16_hi);
+
+  int16x8_t y_c256_lo = vmlsq_s16(v_r6, c256_0, dy256);
+  int16x8_t y_c256_hi = vmlsq_s16(v_r6, c256_1, dy256);
 
   int16x8_t base_y_c256_lo = vshrq_n_s16(y_c256_lo, 6);
   int16x8_t base_y_c256_hi = vshrq_n_s16(y_c256_hi, 6);
 
-  base_y_c256_lo = vmaxq_s16(min_base_y256, base_y_c256_lo);
-  base_y_c256_hi = vmaxq_s16(min_base_y256, base_y_c256_hi);
-
 #if !AOM_ARCH_AARCH64
   int16_t min_y = vgetq_lane_s16(base_y_c256_hi, 7);
   int16_t max_y = vgetq_lane_s16(base_y_c256_lo, 0);
@@ -1799,13 +1768,13 @@
   uint16x8_t diff_lo = vsubl_u8(a1_y0, a0_y0);
   uint16x8_t diff_hi = vsubl_u8(a1_y1, a0_y1);
-  // a[x] * 32 + 16
-  uint16x8_t a32_lo = vmlal_u8(vdupq_n_u16(16), a0_y0, vdup_n_u8(32));
-  uint16x8_t a32_hi = vmlal_u8(vdupq_n_u16(16), a0_y1, vdup_n_u8(32));
+  // a[x] * 32
+  uint16x8_t a32_lo = vshll_n_u8(a0_y0, 5);
+  uint16x8_t a32_hi = vshll_n_u8(a0_y1, 5);
 
   uint16x8_t res0 = vmlaq_u16(a32_lo, diff_lo, shifty_lo);
   uint16x8_t res1 = vmlaq_u16(a32_hi, diff_hi, shifty_hi);
 
-  return vcombine_u8(vshrn_n_u16(res0, 5), vshrn_n_u16(res1, 5));
+  return vcombine_u8(vrshrn_n_u16(res0, 5), vrshrn_n_u16(res1, 5));
 }
 
 static void dr_prediction_z2_4xH_neon(int H, uint8_t *dst, ptrdiff_t stride,
@@ -1813,17 +1782,12 @@
                                       int upsample_above, int upsample_left,
                                       int dx, int dy) {
   const int min_base_x = -(1 << upsample_above);
-  const int min_base_y = -(1 << upsample_left);
   const int frac_bits_x = 6 - upsample_above;
   const int frac_bits_y = 6 - upsample_left;
 
   assert(dx > 0);
-  // pre-filter above pixels
-  // store in temp buffers:
-  //   above[x] * 32 + 16
-  //   above[x+1] - above[x]
-  // final pixels will be calculated as:
-  //   (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5
+  //   above0 * (32 - shift) + above1 * shift
+  // = (above1 - above0) * shift + above0 * 32
 
 #if AOM_ARCH_AARCH64
   // Use ext rather than loading left + 14 directly to avoid over-read.
@@ -1852,11 +1816,10 @@
       uint8x8_t a1_x = a1_x_u8;
 
       uint16x8_t diff = vsubl_u8(a1_x, a0_x);  // a[x+1] - a[x]
-      uint16x8_t a32 =
-          vmlal_u8(vdupq_n_u16(16), a0_x, vdup_n_u8(32));  // a[x] * 32 + 16
+      uint16x8_t a32 = vshll_n_u8(a0_x, 5);  // a[x] * 32
       uint16x8_t res =
           vmlaq_u16(a32, diff, vcombine_u16(shift0, vdup_n_u16(0)));
-      uint8x8_t resx = vshrn_n_u16(res, 5);
+      uint8x8_t resx = vrshrn_n_u16(res, 5);
       vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(resx), 0);
     } else if (base_min_diff < 4) {
       uint8x8_t a0_x_u8, a1_x_u8;
@@ -1869,17 +1832,16 @@
       uint16x4_t a0_y;
       uint16x4_t a1_y;
       uint16x4_t shift1;
-      dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, min_base_y,
-                                     frac_bits_y, &a0_y, &a1_y, &shift1);
+      dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, frac_bits_y,
+                                     &a0_y, &a1_y, &shift1);
       a0_x = vcombine_u16(vget_low_u16(a0_x), a0_y);
       a1_x = vcombine_u16(vget_low_u16(a1_x), a1_y);
 
       uint16x8_t shift = vcombine_u16(shift0, shift1);
       uint16x8_t diff = vsubq_u16(a1_x, a0_x);  // a[x+1] - a[x]
-      uint16x8_t a32 =
-          vmlaq_n_u16(vdupq_n_u16(16), a0_x, 32);  // a[x] * 32 + 16
+      uint16x8_t a32 = vshlq_n_u16(a0_x, 5);  // a[x] * 32
       uint16x8_t res = vmlaq_u16(a32, diff, shift);
-      uint8x8_t resx = vshrn_n_u16(res, 5);
+      uint8x8_t resx = vrshrn_n_u16(res, 5);
       uint8x8_t resy = vext_u8(resx, vdup_n_u8(0), 4);
 
       uint8x8_t mask = vld1_u8(BaseMask[base_min_diff]);
@@ -1888,12 +1850,12 @@
     } else {
       uint16x4_t a0_y, a1_y;
       uint16x4_t shift1;
-      dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, min_base_y,
-                                     frac_bits_y, &a0_y, &a1_y, &shift1);
-      uint16x4_t diff = vsub_u16(a1_y, a0_y);                 // a[x+1] - a[x]
-      uint16x4_t a32 = vmla_n_u16(vdup_n_u16(16), a0_y, 32);  // a[x] * 32 + 16
+      dr_prediction_z2_4xH_left_neon(LEFT, upsample_left, dy, r, frac_bits_y,
+                                     &a0_y, &a1_y, &shift1);
+      uint16x4_t diff = vsub_u16(a1_y, a0_y);  // a[x+1] - a[x]
+      uint16x4_t a32 = vshl_n_u16(a0_y, 5);  // a[x] * 32
       uint16x4_t res = vmla_u16(a32, diff, shift1);
-      uint8x8_t resy = vshrn_n_u16(vcombine_u16(res, vdup_n_u16(0)), 5);
+      uint8x8_t resy = vrshrn_n_u16(vcombine_u16(res, vdup_n_u16(0)), 5);
 
       vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(resy), 0);
     }
@@ -1908,16 +1870,11 @@
                                       int upsample_above, int upsample_left,
                                       int dx, int dy) {
   const int min_base_x = -(1 << upsample_above);
-  const int min_base_y = -(1 << upsample_left);
   const int frac_bits_x = 6 - upsample_above;
   const int frac_bits_y = 6 - upsample_left;
 
-  // pre-filter above pixels
-  // store in temp buffers:
-  //   above[x] * 32 + 16
-  //   above[x+1] - above[x]
-  // final pixels will be calculated as:
-  //   (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5
+  //   above0 * (32 - shift) + above1 * shift
+  // = (above1 - above0) * shift + above0 * 32
 
 #if AOM_ARCH_AARCH64
   // Use ext rather than loading left + 30 directly to avoid over-read.
@@ -1945,14 +1902,14 @@
     } else if (base_min_diff < 8) {
       uint8x8_t resx =
           dr_prediction_z2_8xH_above_neon(above, upsample_above, dx, base_x, y);
-      uint8x8_t resy = dr_prediction_z2_8xH_left_neon(
-          LEFT, upsample_left, dy, r, min_base_y, frac_bits_y);
+      uint8x8_t resy = dr_prediction_z2_8xH_left_neon(LEFT, upsample_left, dy,
+                                                      r, frac_bits_y);
       uint8x8_t mask = vld1_u8(BaseMask[base_min_diff]);
       uint8x8_t resxy = vbsl_u8(mask, resy, resx);
       vst1_u8(dst, resxy);
     } else {
-      uint8x8_t resy = dr_prediction_z2_8xH_left_neon(
-          LEFT, upsample_left, dy, r, min_base_y, frac_bits_y);
+      uint8x8_t resy = dr_prediction_z2_8xH_left_neon(LEFT, upsample_left, dy,
+                                                      r, frac_bits_y);
       vst1_u8(dst, resy);
     }