[arm]: Improve av1_quantize_fp_32x32_neon().

1.05x to 1.24x faster than the previous version
depending on the last nonzero coeff position.
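
The gain varies with the last nonzero coefficient position because the new
pre-scan pass zeroes the tail of the block with memset() rather than
quantizing it. For reference, the per-coefficient logic vectorized by the new
quantize_fp_logscale_8() helper corresponds roughly to the scalar sketch
below (illustrative only, not part of the patch; the helper name here is
made up):

  #include <stdint.h>
  #include <stdlib.h>

  // One coefficient of the fp quantizer for log_scale >= 1 (1 for 32x32).
  static int16_t quantize_fp_logscale_scalar(int16_t coeff, int16_t quant,
                                             int16_t dequant, int16_t round,
                                             int log_scale, int16_t *dqcoeff) {
    const int abs_coeff = abs(coeff);
    // Coefficients below dequant >> (1 + log_scale) cannot produce a nonzero
    // quantized value; the pre-scan pass applies the same threshold.
    if (abs_coeff < (dequant >> (1 + log_scale))) {
      *dqcoeff = 0;
      return 0;
    }
    // The rounding term is first reduced by ROUND_POWER_OF_TWO(round, log_scale).
    const int tmp = abs_coeff + ((round + (1 << (log_scale - 1))) >> log_scale);
    const int abs_qcoeff = (tmp * quant) >> (16 - log_scale);
    const int abs_dqcoeff = (abs_qcoeff * dequant) >> log_scale;
    *dqcoeff = (int16_t)(coeff < 0 ? -abs_dqcoeff : abs_dqcoeff);
    return (int16_t)(coeff < 0 ? -abs_qcoeff : abs_qcoeff);
  }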

Bug: b/231719821

Change-Id: I5695d8912a36419142ec4ff7e9df852816df94b6
diff --git a/av1/encoder/arm/neon/quantize_neon.c b/av1/encoder/arm/neon/quantize_neon.c
index 289218d..8b5888f 100644
--- a/av1/encoder/arm/neon/quantize_neon.c
+++ b/av1/encoder/arm/neon/quantize_neon.c
@@ -11,6 +11,7 @@
 
 #include <arm_neon.h>
 
+#include <assert.h>
 #include <math.h>
 
 #include "aom_dsp/arm/mem_neon.h"
@@ -173,6 +174,124 @@
   *eob_ptr = get_max_eob(v_eobmax_76543210);
 }
 
+static INLINE uint16x8_t quantize_fp_logscale_8(
+    const tran_low_t *coeff_ptr, tran_low_t *qcoeff_ptr,
+    tran_low_t *dqcoeff_ptr, int16x8_t v_quant, int16x8_t v_dequant,
+    int16x8_t v_round, int16x8_t v_zero, int log_scale) {
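+  // Quantize and dequantize eight consecutive coefficients, returning a mask
+  // of the lanes whose quantized value is nonzero.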
+  const int16x8_t v_log_scale_minus_1 = vdupq_n_s16(log_scale - 1);
+  const int16x8_t v_neg_log_scale_plus_1 = vdupq_n_s16(-(1 + log_scale));
+  const int16x8_t v_coeff = load_tran_low_to_s16q(coeff_ptr);
+  const int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15);
+  const int16x8_t v_abs_coeff = vabsq_s16(v_coeff);
+  const uint16x8_t v_mask =
+      vcgeq_s16(v_abs_coeff, vshlq_s16(v_dequant, v_neg_log_scale_plus_1));
+  // const int64_t tmp = v_mask ? (int64_t)abs_coeff + log_scaled_round : 0
+  const int16x8_t v_tmp = vandq_s16(vqaddq_s16(v_abs_coeff, v_round),
+                                    vreinterpretq_s16_u16(v_mask));
+  const int16x8_t v_tmp2 =
+      vqdmulhq_s16(vshlq_s16(v_tmp, v_log_scale_minus_1), v_quant);
+  const uint16x8_t v_nz_mask = vcgtq_s16(v_tmp2, v_zero);
+  const int16x8_t v_qcoeff =
+      vsubq_s16(veorq_s16(v_tmp2, v_coeff_sign), v_coeff_sign);
+  // Multiplying by dequant here can use all 16 bits, so cast to unsigned
+  // before shifting right (vshlq_u16 shifts right for negative shift values).
+  const uint16x8_t v_abs_dqcoeff =
+      vshlq_u16(vreinterpretq_u16_s16(vmulq_s16(v_tmp2, v_dequant)),
+                vdupq_n_s16(-log_scale));
+  const int16x8_t v_dqcoeff =
+      vsubq_s16(veorq_s16(vreinterpretq_s16_u16(v_abs_dqcoeff), v_coeff_sign),
+                v_coeff_sign);
+  store_s16q_to_tran_low(qcoeff_ptr, v_qcoeff);
+  store_s16q_to_tran_low(dqcoeff_ptr, v_dqcoeff);
+  return v_nz_mask;
+}
+
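+// Horizontal add of all eight lanes. The caller only tests the sum against
+// zero, so uint16_t wrap-around of vaddvq_u16 with mask inputs is harmless.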
+static INLINE uint32_t sum_abs_coeff(const uint16x8_t a) {
+#if defined(__aarch64__)
+  return vaddvq_u16(a);
+#else
+  const uint32x4_t b = vpaddlq_u16(a);
+  const uint64x2_t c = vpaddlq_u32(b);
+  const uint64x1_t d = vadd_u64(vget_low_u64(c), vget_high_u64(c));
+  return (uint32_t)vget_lane_u64(d, 0);
+#endif
+}
+
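+// Quantize an entire block when no quantization matrix is in use.
+// av1_quantize_fp_32x32_neon() below calls this with a log_scale of 1.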
+static void quantize_fp_no_qmatrix_neon(
+    const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *round_ptr,
+    const int16_t *quant_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
+    const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *iscan,
+    int log_scale) {
+  const int16x8_t v_zero = vdupq_n_s16(0);
+  int16x8_t v_quant = vld1q_s16(quant_ptr);
+  int16x8_t v_dequant = vld1q_s16(dequant_ptr);
+  const int16x8_t v_round_no_scale = vld1q_s16(round_ptr);
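+  // Multiplying by 1 << (15 - log_scale) with vqrdmulhq_n_s16 computes
+  // ROUND_POWER_OF_TWO(round_ptr[i], log_scale) in each lane.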
+  int16x8_t v_round =
+      vqrdmulhq_n_s16(v_round_no_scale, (int16_t)(1 << (15 - log_scale)));
+  int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1);
+  intptr_t non_zero_count = n_coeffs;
+
+  assert(n_coeffs > 16);
+  // Pre-scan pass: find the last group of 16 coefficients that could quantize
+  // to a nonzero value so the tail of the block can simply be zeroed.
+  const int16x8_t v_dequant_scaled =
+      vshlq_s16(v_dequant, vdupq_n_s16(-(1 + log_scale)));
+  const int16x8_t v_zbin_s16 =
+      vdupq_lane_s16(vget_low_s16(v_dequant_scaled), 1);
+  intptr_t i = n_coeffs;
+  do {
+    const int16x8_t v_coeff_a = load_tran_low_to_s16q(coeff_ptr + i - 8);
+    const int16x8_t v_coeff_b = load_tran_low_to_s16q(coeff_ptr + i - 16);
+    const int16x8_t v_abs_coeff_a = vabsq_s16(v_coeff_a);
+    const int16x8_t v_abs_coeff_b = vabsq_s16(v_coeff_b);
+    const uint16x8_t v_mask_a = vcgeq_s16(v_abs_coeff_a, v_zbin_s16);
+    const uint16x8_t v_mask_b = vcgeq_s16(v_abs_coeff_b, v_zbin_s16);
+    // If all 16 coefficients are within the base ZBIN range, discard them.
+    if (sum_abs_coeff(v_mask_a) + sum_abs_coeff(v_mask_b) == 0) {
+      non_zero_count -= 16;
+    } else {
+      break;
+    }
+    i -= 16;
+  } while (i > 0);
+
+  const intptr_t remaining_zcoeffs = n_coeffs - non_zero_count;
+  memset(qcoeff_ptr + non_zero_count, 0,
+         remaining_zcoeffs * sizeof(*qcoeff_ptr));
+  memset(dqcoeff_ptr + non_zero_count, 0,
+         remaining_zcoeffs * sizeof(*dqcoeff_ptr));
+
+  // Process the DC coefficient and the first seven AC coefficients.
+  uint16x8_t v_nz_mask =
+      quantize_fp_logscale_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant,
+                             v_dequant, v_round, v_zero, log_scale);
+  v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask);
+  // Overwrite the DC constants with the AC constants.
+  v_quant = vdupq_lane_s16(vget_low_s16(v_quant), 1);
+  v_dequant = vdupq_lane_s16(vget_low_s16(v_dequant), 1);
+  v_round = vdupq_lane_s16(vget_low_s16(v_round), 1);
+
+  for (intptr_t count = non_zero_count - 8; count > 0; count -= 8) {
+    coeff_ptr += 8;
+    qcoeff_ptr += 8;
+    dqcoeff_ptr += 8;
+    iscan += 8;
+    v_nz_mask =
+        quantize_fp_logscale_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant,
+                               v_dequant, v_round, v_zero, log_scale);
+    v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask);
+  }
+  *eob_ptr = get_max_eob(v_eobmax_76543210);
+}
+
 void av1_quantize_fp_32x32_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs,
                                 const int16_t *zbin_ptr,
                                 const int16_t *round_ptr,
@@ -181,93 +300,12 @@
                                 tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
                                 const int16_t *dequant_ptr, uint16_t *eob_ptr,
                                 const int16_t *scan, const int16_t *iscan) {
-  const int log_scale = 1;
-  const int rounding[2] = { ROUND_POWER_OF_TWO(round_ptr[0], log_scale),
-                            ROUND_POWER_OF_TWO(round_ptr[1], log_scale) };
-
   (void)zbin_ptr;
   (void)quant_shift_ptr;
   (void)scan;
-
-  memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr));
-  memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr));
-
-  const int16x8_t zero = vdupq_n_s16(0);
-  int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero));
-  int16x8_t round = vdupq_n_s16(rounding[1]);
-  int16x8_t quant = vdupq_n_s16(quant_ptr[1]);
-  int16x8_t dequant = vdupq_n_s16(dequant_ptr[1]);
-  dequant = vsetq_lane_s16(dequant_ptr[0], dequant, 0);
-
-  int16x8_t coeff = load_tran_low_to_s16q(&coeff_ptr[0]);
-
-  int16x8_t abs = vabsq_s16(coeff);
-  uint16x8_t check = vcgeq_s16(abs, vshrq_n_s16(dequant, 2));
-  uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(check)), 0);
-  if (nz_check) {
-    const int16x8_t coeff_sign = vshrq_n_s16(coeff, 15);
-    const int16x8_t v_iscan = vld1q_s16(&iscan[0]);
-    round = vsetq_lane_s16(rounding[0], round, 0);
-    quant = vsetq_lane_s16(quant_ptr[0], quant, 0);
-
-    abs = vqaddq_s16(abs, round);
-    int16x8_t temp = vqdmulhq_s16(abs, quant);
-    int16x8_t qcoeff_temp = vsubq_s16(veorq_s16(temp, coeff_sign), coeff_sign);
-    abs = vreinterpretq_s16_u16(
-        vshrq_n_u16(vreinterpretq_u16_s16(vmulq_s16(temp, dequant)), 1));
-    int16x8_t dqcoeff_temp = vsubq_s16(veorq_s16(abs, coeff_sign), coeff_sign);
-
-    int16x8_t coeff_nz_mask =
-        vbslq_s16(check, qcoeff_temp, load_tran_low_to_s16q(&qcoeff_ptr[0]));
-    store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask);
-    coeff_nz_mask =
-        vbslq_s16(check, dqcoeff_temp, load_tran_low_to_s16q(&dqcoeff_ptr[0]));
-    store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask);
-
-    round = vsetq_lane_s16(rounding[1], round, 0);
-    quant = vsetq_lane_s16(quant_ptr[1], quant, 0);
-
-    uint16x8_t vtmp_mask = vcgtq_s16(abs, zero);
-    const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, check);
-    check = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
-    v_eobmax_76543210 = vbslq_s16(check, v_iscan, v_eobmax_76543210);
-  }
-
-  dequant = vsetq_lane_s16(dequant_ptr[1], dequant, 0);
-
-  for (int i = 8; i < n_coeffs; i += 8) {
-    coeff = load_tran_low_to_s16q(&coeff_ptr[i]);
-    abs = vabsq_s16(coeff);
-    check = vcgeq_s16(abs, vshrq_n_s16(dequant, 2));
-
-    nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(check)), 0);
-    if (nz_check) {
-      const int16x8_t coeff_sign = vshrq_n_s16(coeff, 15);
-      const int16x8_t v_iscan = vld1q_s16(&iscan[i]);
-
-      abs = vqaddq_s16(abs, round);
-      int16x8_t temp = vqdmulhq_s16(abs, quant);
-      int16x8_t qcoeff_temp =
-          vsubq_s16(veorq_s16(temp, coeff_sign), coeff_sign);
-      abs = vreinterpretq_s16_u16(
-          vshrq_n_u16(vreinterpretq_u16_s16(vmulq_s16(temp, dequant)), 1));
-      int16x8_t dqcoeff_temp =
-          vsubq_s16(veorq_s16(abs, coeff_sign), coeff_sign);
-
-      int16x8_t coeff_nz_mask =
-          vbslq_s16(check, qcoeff_temp, load_tran_low_to_s16q(&qcoeff_ptr[i]));
-      store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask);
-      coeff_nz_mask = vbslq_s16(check, dqcoeff_temp,
-                                load_tran_low_to_s16q(&dqcoeff_ptr[i]));
-      store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask);
-
-      uint16x8_t vtmp_mask = vcgtq_s16(abs, zero);
-      const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, check);
-      check = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
-      v_eobmax_76543210 = vbslq_s16(check, v_iscan, v_eobmax_76543210);
-    }
-  }
-  *eob_ptr = get_max_eob(v_eobmax_76543210) + 1;
+  quantize_fp_no_qmatrix_neon(coeff_ptr, n_coeffs, round_ptr, quant_ptr,
+                              qcoeff_ptr, dqcoeff_ptr, dequant_ptr, eob_ptr,
+                              iscan, 1);
 }
 
 void av1_quantize_fp_64x64_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs,