[arm]: Improve av1_quantize_fp_32x32_neon().
1.05x to 1.24x faster than the previous version
depending on the last nonzero coeff position.
Bug: b/231719821
Change-Id: I5695d8912a36419142ec4ff7e9df852816df94b6
diff --git a/av1/encoder/arm/neon/quantize_neon.c b/av1/encoder/arm/neon/quantize_neon.c
index 289218d..8b5888f 100644
--- a/av1/encoder/arm/neon/quantize_neon.c
+++ b/av1/encoder/arm/neon/quantize_neon.c
@@ -11,6 +11,7 @@
#include <arm_neon.h>
+#include <assert.h>
#include <math.h>
#include "aom_dsp/arm/mem_neon.h"
@@ -173,6 +174,115 @@
*eob_ptr = get_max_eob(v_eobmax_76543210);
}
+static INLINE uint16x8_t quantize_fp_logscale_8(
+ const tran_low_t *coeff_ptr, tran_low_t *qcoeff_ptr,
+ tran_low_t *dqcoeff_ptr, int16x8_t v_quant, int16x8_t v_dequant,
+ int16x8_t v_round, int16x8_t v_zero, int log_scale) {
+ const int16x8_t v_log_scale_minus_1 = vdupq_n_s16(log_scale - 1);
+ const int16x8_t v_neg_log_scale_plus_1 = vdupq_n_s16(-(1 + log_scale));
+ const int16x8_t v_coeff = load_tran_low_to_s16q(coeff_ptr);
+ const int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15);
+ const int16x8_t v_abs_coeff = vabsq_s16(v_coeff);
+ const uint16x8_t v_mask =
+ vcgeq_s16(v_abs_coeff, vshlq_s16(v_dequant, v_neg_log_scale_plus_1));
+ // const int64_t tmp = vmask ? (int64_t)abs_coeff + log_scaled_round : 0
+ const int16x8_t v_tmp = vandq_s16(vqaddq_s16(v_abs_coeff, v_round),
+ vreinterpretq_s16_u16(v_mask));
+ const int16x8_t v_tmp2 =
+ vqdmulhq_s16(vshlq_s16(v_tmp, v_log_scale_minus_1), v_quant);
+ const uint16x8_t v_nz_mask = vcgtq_s16(v_tmp2, v_zero);
+ const int16x8_t v_qcoeff =
+ vsubq_s16(veorq_s16(v_tmp2, v_coeff_sign), v_coeff_sign);
+ // Multiplying by dequant here will use all 16 bits. Cast to unsigned before
+ // shifting right. (vshlq_s16 will shift right if shift value is negative)
+ const uint16x8_t v_abs_dqcoeff =
+ vshlq_u16(vreinterpretq_u16_s16(vmulq_s16(v_tmp2, v_dequant)),
+ vdupq_n_s16(-log_scale));
+ const int16x8_t v_dqcoeff =
+ vsubq_s16(veorq_s16(vreinterpretq_s16_u16(v_abs_dqcoeff), v_coeff_sign),
+ v_coeff_sign);
+ store_s16q_to_tran_low(qcoeff_ptr, v_qcoeff);
+ store_s16q_to_tran_low(dqcoeff_ptr, v_dqcoeff);
+ return v_nz_mask;
+}
+
+static INLINE uint32_t sum_abs_coeff(const uint16x8_t a) {
+#if defined(__aarch64__)
+ return vaddvq_u16(a);
+#else
+ const uint32x4_t b = vpaddlq_u16(a);
+ const uint64x2_t c = vpaddlq_u32(b);
+ const uint64x1_t d = vadd_u64(vget_low_u64(c), vget_high_u64(c));
+ return (uint32_t)vget_lane_u64(d, 0);
+#endif
+}
+
+static void quantize_fp_no_qmatrix_neon(
+ const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *round_ptr,
+ const int16_t *quant_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
+ const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *iscan,
+ int log_scale) {
+ const int16x8_t v_zero = vdupq_n_s16(0);
+ int16x8_t v_quant = vld1q_s16(quant_ptr);
+ int16x8_t v_dequant = vld1q_s16(dequant_ptr);
+ const int16x8_t v_round_no_scale = vld1q_s16(round_ptr);
+ int16x8_t v_round =
+ vqrdmulhq_n_s16(v_round_no_scale, (int16_t)(1 << (15 - log_scale)));
+ int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1);
+ intptr_t non_zero_count = n_coeffs;
+
+ assert(n_coeffs > 16);
+ // Pre-scan pass
+ const int16x8_t v_dequant_scaled =
+ vshlq_s16(v_dequant, vdupq_n_s16(-(1 + log_scale)));
+ const int16x8_t v_zbin_s16 =
+ vdupq_lane_s16(vget_low_s16(v_dequant_scaled), 1);
+ intptr_t i = n_coeffs;
+ do {
+ const int16x8_t v_coeff_a = load_tran_low_to_s16q(coeff_ptr + i - 8);
+ const int16x8_t v_coeff_b = load_tran_low_to_s16q(coeff_ptr + i - 16);
+ const int16x8_t v_abs_coeff_a = vabsq_s16(v_coeff_a);
+ const int16x8_t v_abs_coeff_b = vabsq_s16(v_coeff_b);
+ const uint16x8_t v_mask_a = vcgeq_s16(v_abs_coeff_a, v_zbin_s16);
+ const uint16x8_t v_mask_b = vcgeq_s16(v_abs_coeff_b, v_zbin_s16);
+ // If the coefficient is in the base ZBIN range, then discard.
+ if (sum_abs_coeff(v_mask_a) + sum_abs_coeff(v_mask_b) == 0) {
+ non_zero_count -= 16;
+ } else {
+ break;
+ }
+ i -= 16;
+ } while (i > 0);
+
+ const intptr_t remaining_zcoeffs = n_coeffs - non_zero_count;
+ memset(qcoeff_ptr + non_zero_count, 0,
+ remaining_zcoeffs * sizeof(*qcoeff_ptr));
+ memset(dqcoeff_ptr + non_zero_count, 0,
+ remaining_zcoeffs * sizeof(*dqcoeff_ptr));
+
+ // process dc and the first seven ac coeffs
+ uint16x8_t v_nz_mask =
+ quantize_fp_logscale_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant,
+ v_dequant, v_round, v_zero, log_scale);
+ v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask);
+ // overwrite the dc constants with ac constants
+ v_quant = vdupq_lane_s16(vget_low_s16(v_quant), 1);
+ v_dequant = vdupq_lane_s16(vget_low_s16(v_dequant), 1);
+ v_round = vdupq_lane_s16(vget_low_s16(v_round), 1);
+
+ for (intptr_t count = non_zero_count - 8; count > 0; count -= 8) {
+ coeff_ptr += 8;
+ qcoeff_ptr += 8;
+ dqcoeff_ptr += 8;
+ iscan += 8;
+ v_nz_mask =
+ quantize_fp_logscale_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant,
+ v_dequant, v_round, v_zero, log_scale);
+ v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask);
+ }
+ *eob_ptr = get_max_eob(v_eobmax_76543210);
+}
+
void av1_quantize_fp_32x32_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs,
const int16_t *zbin_ptr,
const int16_t *round_ptr,
@@ -181,93 +291,12 @@
tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
const int16_t *dequant_ptr, uint16_t *eob_ptr,
const int16_t *scan, const int16_t *iscan) {
- const int log_scale = 1;
- const int rounding[2] = { ROUND_POWER_OF_TWO(round_ptr[0], log_scale),
- ROUND_POWER_OF_TWO(round_ptr[1], log_scale) };
-
(void)zbin_ptr;
(void)quant_shift_ptr;
(void)scan;
-
- memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr));
- memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr));
-
- const int16x8_t zero = vdupq_n_s16(0);
- int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero));
- int16x8_t round = vdupq_n_s16(rounding[1]);
- int16x8_t quant = vdupq_n_s16(quant_ptr[1]);
- int16x8_t dequant = vdupq_n_s16(dequant_ptr[1]);
- dequant = vsetq_lane_s16(dequant_ptr[0], dequant, 0);
-
- int16x8_t coeff = load_tran_low_to_s16q(&coeff_ptr[0]);
-
- int16x8_t abs = vabsq_s16(coeff);
- uint16x8_t check = vcgeq_s16(abs, vshrq_n_s16(dequant, 2));
- uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(check)), 0);
- if (nz_check) {
- const int16x8_t coeff_sign = vshrq_n_s16(coeff, 15);
- const int16x8_t v_iscan = vld1q_s16(&iscan[0]);
- round = vsetq_lane_s16(rounding[0], round, 0);
- quant = vsetq_lane_s16(quant_ptr[0], quant, 0);
-
- abs = vqaddq_s16(abs, round);
- int16x8_t temp = vqdmulhq_s16(abs, quant);
- int16x8_t qcoeff_temp = vsubq_s16(veorq_s16(temp, coeff_sign), coeff_sign);
- abs = vreinterpretq_s16_u16(
- vshrq_n_u16(vreinterpretq_u16_s16(vmulq_s16(temp, dequant)), 1));
- int16x8_t dqcoeff_temp = vsubq_s16(veorq_s16(abs, coeff_sign), coeff_sign);
-
- int16x8_t coeff_nz_mask =
- vbslq_s16(check, qcoeff_temp, load_tran_low_to_s16q(&qcoeff_ptr[0]));
- store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask);
- coeff_nz_mask =
- vbslq_s16(check, dqcoeff_temp, load_tran_low_to_s16q(&dqcoeff_ptr[0]));
- store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask);
-
- round = vsetq_lane_s16(rounding[1], round, 0);
- quant = vsetq_lane_s16(quant_ptr[1], quant, 0);
-
- uint16x8_t vtmp_mask = vcgtq_s16(abs, zero);
- const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, check);
- check = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
- v_eobmax_76543210 = vbslq_s16(check, v_iscan, v_eobmax_76543210);
- }
-
- dequant = vsetq_lane_s16(dequant_ptr[1], dequant, 0);
-
- for (int i = 8; i < n_coeffs; i += 8) {
- coeff = load_tran_low_to_s16q(&coeff_ptr[i]);
- abs = vabsq_s16(coeff);
- check = vcgeq_s16(abs, vshrq_n_s16(dequant, 2));
-
- nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(check)), 0);
- if (nz_check) {
- const int16x8_t coeff_sign = vshrq_n_s16(coeff, 15);
- const int16x8_t v_iscan = vld1q_s16(&iscan[i]);
-
- abs = vqaddq_s16(abs, round);
- int16x8_t temp = vqdmulhq_s16(abs, quant);
- int16x8_t qcoeff_temp =
- vsubq_s16(veorq_s16(temp, coeff_sign), coeff_sign);
- abs = vreinterpretq_s16_u16(
- vshrq_n_u16(vreinterpretq_u16_s16(vmulq_s16(temp, dequant)), 1));
- int16x8_t dqcoeff_temp =
- vsubq_s16(veorq_s16(abs, coeff_sign), coeff_sign);
-
- int16x8_t coeff_nz_mask =
- vbslq_s16(check, qcoeff_temp, load_tran_low_to_s16q(&qcoeff_ptr[i]));
- store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask);
- coeff_nz_mask = vbslq_s16(check, dqcoeff_temp,
- load_tran_low_to_s16q(&dqcoeff_ptr[i]));
- store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask);
-
- uint16x8_t vtmp_mask = vcgtq_s16(abs, zero);
- const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, check);
- check = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
- v_eobmax_76543210 = vbslq_s16(check, v_iscan, v_eobmax_76543210);
- }
- }
- *eob_ptr = get_max_eob(v_eobmax_76543210) + 1;
+ quantize_fp_no_qmatrix_neon(coeff_ptr, n_coeffs, round_ptr, quant_ptr,
+ qcoeff_ptr, dqcoeff_ptr, dequant_ptr, eob_ptr,
+ iscan, 1);
}
void av1_quantize_fp_64x64_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs,