Optimize aom_quantize_b functions Optimize aom_quantize_b_neon, aom_quantize_b_32x32_neon, aom_quantize_b_64x64_neon by removing the initial memset function calls that write zeroes to qcoeff_ptr and dqcoeff_ptr. Instead, write those zero entries inside the processing loops. In av1_quantize.c the aom_quantize_b functions are called with NULL values for qm_ptr and iqm_ptr. The unit tests also have the same behaviour. When these pointers are not NULL, the aom_quantize_b_helper_c function is called instead. Remove the NULL checks for these pointer inside the quantize_b functions. Change-Id: I0032ac5346992fcb949ecaddf8214e3d0b62efd8
diff --git a/av1/encoder/arm/quantize_neon.c b/av1/encoder/arm/quantize_neon.c index 8d13e57..e9c0ca7 100644 --- a/av1/encoder/arm/quantize_neon.c +++ b/av1/encoder/arm/quantize_neon.c
@@ -351,6 +351,33 @@ iscan, 2); } +static inline uint16x8_t quantize_b_logscale0_8( + int16x8_t coeff, int16x8_t abs, uint16x8_t cond, int16x8_t round, + int16x8_t dequant, int16x8_t quant, int16x8_t quant_shift, + tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr) { + const int16x8_t zero = vdupq_n_s16(0); + + int16x8_t coeff_sign = vreinterpretq_s16_u16(vcltq_s16(coeff, zero)); + + int16x8_t tmp = vqaddq_s16(abs, round); + tmp = vsraq_n_s16(tmp, vqdmulhq_s16(tmp, quant), 1); + tmp = vqdmulhq_s16(tmp, quant_shift); + + int16x8_t qcoeff = vsubq_s16(veorq_s16(tmp, coeff_sign), coeff_sign); + qcoeff = vbslq_s16(cond, qcoeff, zero); + store_s16q_to_tran_low(qcoeff_ptr, qcoeff); + + int16x8_t dqcoeff = vmulq_s16(tmp, dequant); + dqcoeff = vsubq_s16(veorq_s16(dqcoeff, coeff_sign), coeff_sign); + dqcoeff = vbslq_s16(cond, dqcoeff, zero); + store_s16q_to_tran_low(dqcoeff_ptr, dqcoeff); + + uint16x8_t tmp_mask = vcgtq_s16(tmp, zero); + uint16x8_t nz_mask = vandq_u16(tmp_mask, cond); + + return nz_mask; +} + void aom_quantize_b_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, @@ -361,88 +388,58 @@ (void)quant_shift_ptr; (void)scan; - const int zbins[2] = { zbin_ptr[0], zbin_ptr[1] }; + int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1); - memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr)); - memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr)); + int16x8_t v_zbins = vdupq_n_s16(zbin_ptr[1]); + int16x8_t v_round = vdupq_n_s16(round_ptr[1]); + int16x8_t v_dequant = vdupq_n_s16(dequant_ptr[1]); + int16x8_t v_quant = vdupq_n_s16(quant_ptr[1]); + // Shift by 1 in order to save one shift in the kernel function. + int16x8_t v_quant_shift = vdupq_n_s16(quant_shift_ptr[1] >> 1); - const int16x8_t zero = vdupq_n_s16(0); - int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero)); - - int16x8_t vzbins = vdupq_n_s16(zbins[1]), vround = vdupq_n_s16(round_ptr[1]); - int16x8_t vdequant = vdupq_n_s16(dequant_ptr[1]); - int16x8_t vquant = vdupq_n_s16(quant_ptr[1]); - int16x8_t vquant_shift = vdupq_n_s16(quant_shift_ptr[1]); - - int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]); - int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15); + int16x8_t v_zbins0 = vsetq_lane_s16(zbin_ptr[0], v_zbins, 0); + int16x8_t v_coeff = load_tran_low_to_s16q(coeff_ptr); int16x8_t v_abs = vabsq_s16(v_coeff); + uint16x8_t v_cond = vcgeq_s16(v_abs, v_zbins0); - vzbins = vsetq_lane_s16(zbins[0], vzbins, 0); - - uint16x8_t vcond = vcgeq_s16(v_abs, vzbins); - uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0); + uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(v_cond)), 0); if (nz_check) { - vround = vsetq_lane_s16(round_ptr[0], vround, 0); - vquant = vsetq_lane_s16(quant_ptr[0], vquant, 0); - vdequant = vsetq_lane_s16(dequant_ptr[0], vdequant, 0); - vquant_shift = vsetq_lane_s16(quant_shift_ptr[0], vquant_shift, 0); + int16x8_t v_round0 = vsetq_lane_s16(round_ptr[0], v_round, 0); + int16x8_t v_quant0 = vsetq_lane_s16(quant_ptr[0], v_quant, 0); + int16x8_t v_dequant0 = vsetq_lane_s16(dequant_ptr[0], v_dequant, 0); + // Shift by 1 in order to save one shift in the kernel function. + int16x8_t v_quant_shift0 = + vsetq_lane_s16(quant_shift_ptr[0] >> 1, v_quant_shift, 0); - int16x8_t vtmp = vqaddq_s16(v_abs, vround); - int16x8_t vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1); - vtmp2 = vshrq_n_s16(vqdmulhq_s16(vtmp2, vquant_shift), 1); + const uint16x8_t v_nz_mask = quantize_b_logscale0_8( + v_coeff, v_abs, v_cond, v_round0, v_dequant0, v_quant0, v_quant_shift0, + qcoeff_ptr, dqcoeff_ptr); - int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign); - int16x8_t coeff_nz_mask = - vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[0])); - store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask); - int16x8_t v_deq_abs = vmulq_s16(vtmp2, vdequant); - - vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign); - coeff_nz_mask = - vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[0])); - store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask); - - vround = vsetq_lane_s16(round_ptr[1], vround, 0); - vquant = vsetq_lane_s16(quant_ptr[1], vquant, 0); - vdequant = vsetq_lane_s16(dequant_ptr[1], vdequant, 0); - vquant_shift = vsetq_lane_s16(quant_shift_ptr[1], vquant_shift, 0); - - uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero); - const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond); - int16x8_t v_iscan = vld1q_s16(&iscan[0]); - vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210)); - v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210); + int16x8_t v_iscan = vld1q_s16(iscan); + int16x8_t v_eobmax = vmaxq_s16(v_iscan, v_eobmax_76543210); + v_eobmax_76543210 = vbslq_s16(v_nz_mask, v_eobmax, v_eobmax_76543210); + } else { + store_s16q_to_tran_low(qcoeff_ptr, vdupq_n_s16(0)); + store_s16q_to_tran_low(dqcoeff_ptr, vdupq_n_s16(0)); } - vzbins = vsetq_lane_s16(zbins[1], vzbins, 0); for (int i = 8; i < n_coeffs; i += 8) { - v_coeff = load_tran_low_to_s16q(&coeff_ptr[i]); - v_coeff_sign = vshrq_n_s16(v_coeff, 15); + v_coeff = load_tran_low_to_s16q(coeff_ptr + i); v_abs = vabsq_s16(v_coeff); - vcond = vcgeq_s16(v_abs, vzbins); + v_cond = vcgeq_s16(v_abs, v_zbins); - nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0); + nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(v_cond)), 0); if (nz_check) { - int16x8_t vtmp = vqaddq_s16(v_abs, vround); - int16x8_t vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1); + const uint16x8_t v_nz_mask = quantize_b_logscale0_8( + v_coeff, v_abs, v_cond, v_round, v_dequant, v_quant, v_quant_shift, + qcoeff_ptr + i, dqcoeff_ptr + i); - vtmp2 = vshrq_n_s16(vqdmulhq_s16(vtmp2, vquant_shift), 1); - int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign); - int16x8_t coeff_nz_mask = - vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[i])); - store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask); - int16x8_t v_deq_abs = vmulq_s16(vtmp2, vdequant); - vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign); - coeff_nz_mask = - vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[i])); - store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask); - - uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero); - const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond); - int16x8_t v_iscan = vld1q_s16(&iscan[i]); - vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210)); - v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210); + int16x8_t v_iscan = vld1q_s16(iscan + i); + int16x8_t v_eobmax = vmaxq_s16(v_iscan, v_eobmax_76543210); + v_eobmax_76543210 = vbslq_s16(v_nz_mask, v_eobmax, v_eobmax_76543210); + } else { + store_s16q_to_tran_low(qcoeff_ptr + i, vdupq_n_s16(0)); + store_s16q_to_tran_low(dqcoeff_ptr + i, vdupq_n_s16(0)); } } *eob_ptr = get_max_eob(v_eobmax_76543210); @@ -899,6 +896,35 @@ } } +static inline uint16x8_t quantize_b_logscale1_8( + int16x8_t coeff, int16x8_t abs, uint16x8_t cond, int16x8_t round, + int16x8_t dequant, int16x8_t quant, int16x8_t quant_shift, + tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr) { + const int16x8_t zero = vdupq_n_s16(0); + + int16x8_t coeff_sign = vreinterpretq_s16_u16(vcltq_s16(coeff, zero)); + + int16x8_t tmp = vqaddq_s16(abs, round); + tmp = vsraq_n_s16(tmp, vqdmulhq_s16(tmp, quant), 1); + tmp = vqdmulhq_s16(tmp, quant_shift); + + int16x8_t qcoeff = vsubq_s16(veorq_s16(tmp, coeff_sign), coeff_sign); + qcoeff = vbslq_s16(cond, qcoeff, zero); + store_s16q_to_tran_low(qcoeff_ptr, qcoeff); + + // Shift by log_scale = 1. + int16x8_t dqcoeff = vreinterpretq_s16_u16(vhaddq_u16( + vreinterpretq_u16_s16(vmulq_s16(tmp, dequant)), vdupq_n_u16(0))); + dqcoeff = vsubq_s16(veorq_s16(dqcoeff, coeff_sign), coeff_sign); + dqcoeff = vbslq_s16(cond, dqcoeff, zero); + store_s16q_to_tran_low(dqcoeff_ptr, dqcoeff); + + uint16x8_t tmp_mask = vcgtq_s16(tmp, zero); + const uint16x8_t nz_mask = vandq_u16(tmp_mask, cond); + + return nz_mask; +} + void aom_quantize_b_32x32_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, const int16_t *round_ptr, @@ -907,10 +933,100 @@ tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan) { - aom_quantize_b_helper_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr, - quant_ptr, quant_shift_ptr, qcoeff_ptr, - dqcoeff_ptr, dequant_ptr, eob_ptr, scan, iscan, - NULL, NULL, 1); + (void)scan; + + const int log_scale = 1; + const int zbins[2] = { ROUND_POWER_OF_TWO(zbin_ptr[0], log_scale), + ROUND_POWER_OF_TWO(zbin_ptr[1], log_scale) }; + const int rounds[2] = { ROUND_POWER_OF_TWO(round_ptr[0], log_scale), + ROUND_POWER_OF_TWO(round_ptr[1], log_scale) }; + + int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1); + + int16x8_t v_zbins = vdupq_n_s16(zbins[1]); + int16x8_t v_round = vdupq_n_s16(rounds[1]); + int16x8_t v_dequant = vdupq_n_s16(dequant_ptr[1]); + int16x8_t v_quant = vdupq_n_s16(quant_ptr[1]); + int16x8_t v_quant_shift = vdupq_n_s16(quant_shift_ptr[1]); + + int16x8_t v_zbins0 = vsetq_lane_s16(zbins[0], v_zbins, 0); + int16x8_t v_coeff = load_tran_low_to_s16q(coeff_ptr); + int16x8_t v_abs = vabsq_s16(v_coeff); + uint16x8_t v_cond = vcgeq_s16(v_abs, v_zbins0); + + uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(v_cond)), 0); + if (nz_check) { + int16x8_t v_round0 = vsetq_lane_s16(rounds[0], v_round, 0); + int16x8_t v_quant0 = vsetq_lane_s16(quant_ptr[0], v_quant, 0); + int16x8_t v_dequant0 = vsetq_lane_s16(dequant_ptr[0], v_dequant, 0); + int16x8_t v_quant_shift0 = + vsetq_lane_s16(quant_shift_ptr[0], v_quant_shift, 0); + + const uint16x8_t v_nz_mask = quantize_b_logscale1_8( + v_coeff, v_abs, v_cond, v_round0, v_dequant0, v_quant0, v_quant_shift0, + qcoeff_ptr, dqcoeff_ptr); + + int16x8_t v_iscan = vld1q_s16(iscan); + int16x8_t v_eobmax = vmaxq_s16(v_iscan, v_eobmax_76543210); + v_eobmax_76543210 = vbslq_s16(v_nz_mask, v_eobmax, v_eobmax_76543210); + } else { + store_s16q_to_tran_low(qcoeff_ptr, vdupq_n_s16(0)); + store_s16q_to_tran_low(dqcoeff_ptr, vdupq_n_s16(0)); + } + + for (int i = 8; i < n_coeffs; i += 8) { + v_coeff = load_tran_low_to_s16q(coeff_ptr + i); + v_abs = vabsq_s16(v_coeff); + v_cond = vcgeq_s16(v_abs, v_zbins); + + nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(v_cond)), 0); + if (nz_check) { + const uint16x8_t v_nz_mask = quantize_b_logscale1_8( + v_coeff, v_abs, v_cond, v_round, v_dequant, v_quant, v_quant_shift, + qcoeff_ptr + i, dqcoeff_ptr + i); + + int16x8_t v_iscan = vld1q_s16(iscan + i); + int16x8_t v_eobmax = vmaxq_s16(v_iscan, v_eobmax_76543210); + v_eobmax_76543210 = vbslq_s16(v_nz_mask, v_eobmax, v_eobmax_76543210); + } else { + store_s16q_to_tran_low(qcoeff_ptr + i, vdupq_n_s16(0)); + store_s16q_to_tran_low(dqcoeff_ptr + i, vdupq_n_s16(0)); + } + } + *eob_ptr = get_max_eob(v_eobmax_76543210); +} + +static inline uint16x8_t quantize_b_logscale2_8( + int16x8_t coeff, int16x8_t abs, uint16x8_t cond, int16x8_t round, + int16x8_t dequant, int16x8_t quant, int16x8_t quant_shift, + tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr) { + const int16x8_t zero = vdupq_n_s16(0); + const int16x8_t one = vdupq_n_s16(1); + + int16x8_t coeff_sign = vreinterpretq_s16_u16(vcltq_s16(coeff, zero)); + + int16x8_t tmp = vqaddq_s16(abs, round); + tmp = vsraq_n_s16(tmp, vqdmulhq_s16(tmp, quant), 1); + int16x8_t ones = vandq_s16(vshrq_n_s16(vmulq_s16(tmp, quant_shift), 14), one); + tmp = vqdmulhq_s16(tmp, quant_shift); + tmp = vaddq_s16(vshlq_s16(tmp, one), ones); + + int16x8_t qcoeff = vsubq_s16(veorq_s16(tmp, coeff_sign), coeff_sign); + qcoeff = vbslq_s16(cond, qcoeff, zero); + store_s16q_to_tran_low(qcoeff_ptr, qcoeff); + + // Shift right by log_scale = 2. + int16x8_t dqcoeff = vreinterpretq_s16_u16( + vshrq_n_u16(vreinterpretq_u16_s16(vmulq_s16(tmp, dequant)), 2)); + dqcoeff = vorrq_s16(vshlq_n_s16(vqdmulhq_s16(tmp, dequant), 13), dqcoeff); + dqcoeff = vsubq_s16(veorq_s16(dqcoeff, coeff_sign), coeff_sign); + dqcoeff = vbslq_s16(cond, dqcoeff, zero); + store_s16q_to_tran_low(dqcoeff_ptr, dqcoeff); + + uint16x8_t tmp_mask = vcgtq_s16(tmp, zero); + const uint16x8_t nz_mask = vandq_u16(tmp_mask, cond); + + return nz_mask; } void aom_quantize_b_64x64_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs, @@ -921,8 +1037,65 @@ tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan) { - aom_quantize_b_helper_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr, - quant_ptr, quant_shift_ptr, qcoeff_ptr, - dqcoeff_ptr, dequant_ptr, eob_ptr, scan, iscan, - NULL, NULL, 2); + (void)scan; + + const int log_scale = 2; + const int zbins[2] = { ROUND_POWER_OF_TWO(zbin_ptr[0], log_scale), + ROUND_POWER_OF_TWO(zbin_ptr[1], log_scale) }; + const int rounds[2] = { ROUND_POWER_OF_TWO(round_ptr[0], log_scale), + ROUND_POWER_OF_TWO(round_ptr[1], log_scale) }; + + int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1); + + int16x8_t v_zbins = vdupq_n_s16(zbins[1]); + int16x8_t v_round = vdupq_n_s16(rounds[1]); + int16x8_t v_dequant = vdupq_n_s16(dequant_ptr[1]); + int16x8_t v_quant = vdupq_n_s16(quant_ptr[1]); + int16x8_t v_quant_shift = vdupq_n_s16(quant_shift_ptr[1]); + + int16x8_t v_zbins0 = vsetq_lane_s16(zbins[0], v_zbins, 0); + int16x8_t v_coeff = load_tran_low_to_s16q(coeff_ptr); + int16x8_t v_abs = vabsq_s16(v_coeff); + uint16x8_t v_cond = vcgeq_s16(v_abs, v_zbins0); + + uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(v_cond)), 0); + if (nz_check) { + int16x8_t v_round0 = vsetq_lane_s16(rounds[0], v_round, 0); + int16x8_t v_quant0 = vsetq_lane_s16(quant_ptr[0], v_quant, 0); + int16x8_t v_dequant0 = vsetq_lane_s16(dequant_ptr[0], v_dequant, 0); + int16x8_t v_quant_shift0 = + vsetq_lane_s16(quant_shift_ptr[0], v_quant_shift, 0); + + const uint16x8_t v_nz_mask = quantize_b_logscale2_8( + v_coeff, v_abs, v_cond, v_round0, v_dequant0, v_quant0, v_quant_shift0, + qcoeff_ptr, dqcoeff_ptr); + + int16x8_t v_iscan = vld1q_s16(iscan); + int16x8_t v_eobmax = vmaxq_s16(v_iscan, v_eobmax_76543210); + v_eobmax_76543210 = vbslq_s16(v_nz_mask, v_eobmax, v_eobmax_76543210); + } else { + store_s16q_to_tran_low(qcoeff_ptr, vdupq_n_s16(0)); + store_s16q_to_tran_low(dqcoeff_ptr, vdupq_n_s16(0)); + } + + for (int i = 8; i < n_coeffs; i += 8) { + v_coeff = load_tran_low_to_s16q(coeff_ptr + i); + v_abs = vabsq_s16(v_coeff); + v_cond = vcgeq_s16(v_abs, v_zbins); + + nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(v_cond)), 0); + if (nz_check) { + const uint16x8_t v_nz_mask = quantize_b_logscale2_8( + v_coeff, v_abs, v_cond, v_round, v_dequant, v_quant, v_quant_shift, + qcoeff_ptr + i, dqcoeff_ptr + i); + + int16x8_t v_iscan = vld1q_s16(iscan + i); + int16x8_t v_eobmax = vmaxq_s16(v_iscan, v_eobmax_76543210); + v_eobmax_76543210 = vbslq_s16(v_nz_mask, v_eobmax, v_eobmax_76543210); + } else { + store_s16q_to_tran_low(qcoeff_ptr + i, vdupq_n_s16(0)); + store_s16q_to_tran_low(dqcoeff_ptr + i, vdupq_n_s16(0)); + } + } + *eob_ptr = get_max_eob(v_eobmax_76543210); }