Count down in Neon variance loops Counting down and terminating on a zero comparison allows us to use flag-setting arithmetic instructions, avoiding an additional CMP instruction before the conditional branch. This doesn't really affect performance but code size is reduced by a useful amount since loop prologues are also shorter - especially with older compilers. Change-Id: I3f75319654c13006591ce861f5a0c2d3230a5e32
diff --git a/aom_dsp/arm/variance_neon.c b/aom_dsp/arm/variance_neon.c index 40e40f0..e2bc96b 100644 --- a/aom_dsp/arm/variance_neon.c +++ b/aom_dsp/arm/variance_neon.c
@@ -27,7 +27,7 @@ uint32x4_t ref_sum = vdupq_n_u32(0); uint32x4_t sse_u32 = vdupq_n_u32(0); - int i = 0; + int i = h; do { uint8x16_t s = load_unaligned_u8q(src, src_stride); uint8x16_t r = load_unaligned_u8q(ref, ref_stride); @@ -40,8 +40,8 @@ src += 4 * src_stride; ref += 4 * ref_stride; - i += 4; - } while (i < h); + i -= 4; + } while (i != 0); int32x4_t sum_diff = vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum)); @@ -56,7 +56,7 @@ uint32x4_t ref_sum = vdupq_n_u32(0); uint32x4_t sse_u32 = vdupq_n_u32(0); - int i = 0; + int i = h; do { uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride)); uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride)); @@ -69,8 +69,8 @@ src += 2 * src_stride; ref += 2 * ref_stride; - i += 2; - } while (i < h); + i -= 2; + } while (i != 0); int32x4_t sum_diff = vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum)); @@ -85,7 +85,7 @@ uint32x4_t ref_sum = vdupq_n_u32(0); uint32x4_t sse_u32 = vdupq_n_u32(0); - int i = 0; + int i = h; do { uint8x16_t s = vld1q_u8(src); uint8x16_t r = vld1q_u8(ref); @@ -98,8 +98,7 @@ src += src_stride; ref += ref_stride; - i++; - } while (i < h); + } while (--i != 0); int32x4_t sum_diff = vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum)); @@ -114,7 +113,7 @@ uint32x4_t ref_sum = vdupq_n_u32(0); uint32x4_t sse_u32 = vdupq_n_u32(0); - int i = 0; + int i = h; do { int j = 0; do { @@ -132,8 +131,7 @@ src += src_stride; ref += ref_stride; - i++; - } while (i < h); + } while (--i != 0); int32x4_t sum_diff = vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum)); @@ -171,7 +169,7 @@ // 32767 / 255 ~= 128, but we use an 8-wide accumulator; so 256 4-wide rows. assert(h <= 256); - int i = 0; + int i = h; do { uint8x8_t s = load_unaligned_u8(src, src_stride); uint8x8_t r = load_unaligned_u8(ref, ref_stride); @@ -184,8 +182,8 @@ src += 2 * src_stride; ref += 2 * ref_stride; - i += 2; - } while (i < h); + i -= 2; + } while (i != 0); *sum = horizontal_add_s16x8(sum_s16); *sse = (uint32_t)horizontal_add_s32x4(sse_s32); @@ -201,7 +199,7 @@ // 32767 / 255 ~= 128 assert(h <= 128); - int i = 0; + int i = h; do { uint8x8_t s = vld1_u8(src); uint8x8_t r = vld1_u8(ref); @@ -215,8 +213,7 @@ src += src_stride; ref += ref_stride; - i++; - } while (i < h); + } while (--i != 0); *sum = horizontal_add_s16x8(sum_s16); *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1])); @@ -232,7 +229,7 @@ // 32767 / 255 ~= 128, so 128 16-wide rows. assert(h <= 128); - int i = 0; + int i = h; do { uint8x16_t s = vld1q_u8(src); uint8x16_t r = vld1q_u8(ref); @@ -256,8 +253,7 @@ src += src_stride; ref += ref_stride; - i++; - } while (i < h); + } while (--i != 0); *sum = horizontal_add_s16x8(vaddq_s16(sum_s16[0], sum_s16[1])); *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1])); @@ -416,7 +412,7 @@ unsigned int *sse, int h) { uint32x4_t sse_u32 = vdupq_n_u32(0); - int i = 0; + int i = h; do { uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride)); uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride)); @@ -427,8 +423,8 @@ src += 2 * src_stride; ref += 2 * ref_stride; - i += 2; - } while (i < h); + i -= 2; + } while (i != 0); *sse = horizontal_add_u32x4(sse_u32); return horizontal_add_u32x4(sse_u32); @@ -439,7 +435,7 @@ unsigned int *sse, int h) { uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; - int i = 0; + int i = h; do { uint8x16_t s0 = vld1q_u8(src); uint8x16_t s1 = vld1q_u8(src + src_stride); @@ -454,8 +450,8 @@ src += 2 * src_stride; ref += 2 * ref_stride; - i += 2; - } while (i < h); + i -= 2; + } while (i != 0); *sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1])); return horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1])); @@ -483,7 +479,7 @@ uint16x8_t diff[2]; int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; - int i = 0; + int i = h; do { s[0] = vld1_u8(src); src += src_stride; @@ -507,8 +503,8 @@ sse_s32[0] = vmlal_s16(sse_s32[0], diff_hi[0], diff_hi[0]); sse_s32[1] = vmlal_s16(sse_s32[1], diff_hi[1], diff_hi[1]); - i += 2; - } while (i < h); + i -= 2; + } while (i != 0); sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]); @@ -525,7 +521,7 @@ int32x4_t sse_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0) }; - int i = 0; + int i = h; do { s[0] = vld1q_u8(src); src += src_stride; @@ -561,8 +557,8 @@ sse_s32[2] = vmlal_s16(sse_s32[2], diff_hi[2], diff_hi[2]); sse_s32[3] = vmlal_s16(sse_s32[3], diff_hi[3], diff_hi[3]); - i += 2; - } while (i < h); + i -= 2; + } while (i != 0); sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]); sse_s32[2] = vaddq_s32(sse_s32[2], sse_s32[3]); @@ -647,7 +643,7 @@ int h) { uint64x2_t square_result = vdupq_n_u64(0); uint32_t d0, d1; - int i = 0; + int i = h; uint8_t *dst_ptr = dst; uint16_t *src_ptr = src; do { @@ -678,8 +674,8 @@ const uint16x8_t src_16x8 = vcombine_u16(src0_16x4, src1_16x4); COMPUTE_MSE_16BIT(src_16x8, dst_16x8) - i += 2; - } while (i < h); + i -= 2; + } while (i != 0); uint64x1_t sum = vadd_u64(vget_high_u64(square_result), vget_low_u64(square_result)); return vget_lane_u64(sum, 0); @@ -689,7 +685,7 @@ uint16_t *src, int sstride, int h) { uint64x2_t square_result = vdupq_n_u64(0); - int i = 0; + int i = h; do { // d7 d6 d5 d4 d3 d2 d1 d0 - 8 bit const uint16x8_t dst_16x8 = vmovl_u8(vld1_u8(&dst[i * dstride])); @@ -697,8 +693,7 @@ const uint16x8_t src_16x8 = vld1q_u16(&src[i * sstride]); COMPUTE_MSE_16BIT(src_16x8, dst_16x8) - i++; - } while (i < h); + } while (--i != 0); uint64x1_t sum = vadd_u64(vget_high_u64(square_result), vget_low_u64(square_result)); return vget_lane_u64(sum, 0);