Count down in Neon variance loops
Counting down and terminating on a zero comparison allows us
to use flag-setting arithmetic instructions, avoiding
an additional CMP instruction before the conditional branch.
This doesn't measurably affect performance, but code size is reduced
by a useful amount, since loop prologues are also shorter —
especially with older compilers.
Change-Id: I3f75319654c13006591ce861f5a0c2d3230a5e32
diff --git a/aom_dsp/arm/variance_neon.c b/aom_dsp/arm/variance_neon.c
index 40e40f0..e2bc96b 100644
--- a/aom_dsp/arm/variance_neon.c
+++ b/aom_dsp/arm/variance_neon.c
@@ -27,7 +27,7 @@
uint32x4_t ref_sum = vdupq_n_u32(0);
uint32x4_t sse_u32 = vdupq_n_u32(0);
- int i = 0;
+ int i = h;
do {
uint8x16_t s = load_unaligned_u8q(src, src_stride);
uint8x16_t r = load_unaligned_u8q(ref, ref_stride);
@@ -40,8 +40,8 @@
src += 4 * src_stride;
ref += 4 * ref_stride;
- i += 4;
- } while (i < h);
+ i -= 4;
+ } while (i != 0);
int32x4_t sum_diff =
vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
@@ -56,7 +56,7 @@
uint32x4_t ref_sum = vdupq_n_u32(0);
uint32x4_t sse_u32 = vdupq_n_u32(0);
- int i = 0;
+ int i = h;
do {
uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride));
uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride));
@@ -69,8 +69,8 @@
src += 2 * src_stride;
ref += 2 * ref_stride;
- i += 2;
- } while (i < h);
+ i -= 2;
+ } while (i != 0);
int32x4_t sum_diff =
vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
@@ -85,7 +85,7 @@
uint32x4_t ref_sum = vdupq_n_u32(0);
uint32x4_t sse_u32 = vdupq_n_u32(0);
- int i = 0;
+ int i = h;
do {
uint8x16_t s = vld1q_u8(src);
uint8x16_t r = vld1q_u8(ref);
@@ -98,8 +98,7 @@
src += src_stride;
ref += ref_stride;
- i++;
- } while (i < h);
+ } while (--i != 0);
int32x4_t sum_diff =
vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
@@ -114,7 +113,7 @@
uint32x4_t ref_sum = vdupq_n_u32(0);
uint32x4_t sse_u32 = vdupq_n_u32(0);
- int i = 0;
+ int i = h;
do {
int j = 0;
do {
@@ -132,8 +131,7 @@
src += src_stride;
ref += ref_stride;
- i++;
- } while (i < h);
+ } while (--i != 0);
int32x4_t sum_diff =
vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
@@ -171,7 +169,7 @@
// 32767 / 255 ~= 128, but we use an 8-wide accumulator; so 256 4-wide rows.
assert(h <= 256);
- int i = 0;
+ int i = h;
do {
uint8x8_t s = load_unaligned_u8(src, src_stride);
uint8x8_t r = load_unaligned_u8(ref, ref_stride);
@@ -184,8 +182,8 @@
src += 2 * src_stride;
ref += 2 * ref_stride;
- i += 2;
- } while (i < h);
+ i -= 2;
+ } while (i != 0);
*sum = horizontal_add_s16x8(sum_s16);
*sse = (uint32_t)horizontal_add_s32x4(sse_s32);
@@ -201,7 +199,7 @@
// 32767 / 255 ~= 128
assert(h <= 128);
- int i = 0;
+ int i = h;
do {
uint8x8_t s = vld1_u8(src);
uint8x8_t r = vld1_u8(ref);
@@ -215,8 +213,7 @@
src += src_stride;
ref += ref_stride;
- i++;
- } while (i < h);
+ } while (--i != 0);
*sum = horizontal_add_s16x8(sum_s16);
*sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1]));
@@ -232,7 +229,7 @@
// 32767 / 255 ~= 128, so 128 16-wide rows.
assert(h <= 128);
- int i = 0;
+ int i = h;
do {
uint8x16_t s = vld1q_u8(src);
uint8x16_t r = vld1q_u8(ref);
@@ -256,8 +253,7 @@
src += src_stride;
ref += ref_stride;
- i++;
- } while (i < h);
+ } while (--i != 0);
*sum = horizontal_add_s16x8(vaddq_s16(sum_s16[0], sum_s16[1]));
*sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1]));
@@ -416,7 +412,7 @@
unsigned int *sse, int h) {
uint32x4_t sse_u32 = vdupq_n_u32(0);
- int i = 0;
+ int i = h;
do {
uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride));
uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride));
@@ -427,8 +423,8 @@
src += 2 * src_stride;
ref += 2 * ref_stride;
- i += 2;
- } while (i < h);
+ i -= 2;
+ } while (i != 0);
*sse = horizontal_add_u32x4(sse_u32);
return horizontal_add_u32x4(sse_u32);
@@ -439,7 +435,7 @@
unsigned int *sse, int h) {
uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
- int i = 0;
+ int i = h;
do {
uint8x16_t s0 = vld1q_u8(src);
uint8x16_t s1 = vld1q_u8(src + src_stride);
@@ -454,8 +450,8 @@
src += 2 * src_stride;
ref += 2 * ref_stride;
- i += 2;
- } while (i < h);
+ i -= 2;
+ } while (i != 0);
*sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
return horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
@@ -483,7 +479,7 @@
uint16x8_t diff[2];
int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
- int i = 0;
+ int i = h;
do {
s[0] = vld1_u8(src);
src += src_stride;
@@ -507,8 +503,8 @@
sse_s32[0] = vmlal_s16(sse_s32[0], diff_hi[0], diff_hi[0]);
sse_s32[1] = vmlal_s16(sse_s32[1], diff_hi[1], diff_hi[1]);
- i += 2;
- } while (i < h);
+ i -= 2;
+ } while (i != 0);
sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]);
@@ -525,7 +521,7 @@
int32x4_t sse_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
vdupq_n_s32(0) };
- int i = 0;
+ int i = h;
do {
s[0] = vld1q_u8(src);
src += src_stride;
@@ -561,8 +557,8 @@
sse_s32[2] = vmlal_s16(sse_s32[2], diff_hi[2], diff_hi[2]);
sse_s32[3] = vmlal_s16(sse_s32[3], diff_hi[3], diff_hi[3]);
- i += 2;
- } while (i < h);
+ i -= 2;
+ } while (i != 0);
sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]);
sse_s32[2] = vaddq_s32(sse_s32[2], sse_s32[3]);
@@ -647,7 +643,7 @@
int h) {
uint64x2_t square_result = vdupq_n_u64(0);
uint32_t d0, d1;
- int i = 0;
+ int i = h;
uint8_t *dst_ptr = dst;
uint16_t *src_ptr = src;
do {
@@ -678,8 +674,8 @@
const uint16x8_t src_16x8 = vcombine_u16(src0_16x4, src1_16x4);
COMPUTE_MSE_16BIT(src_16x8, dst_16x8)
- i += 2;
- } while (i < h);
+ i -= 2;
+ } while (i != 0);
uint64x1_t sum =
vadd_u64(vget_high_u64(square_result), vget_low_u64(square_result));
return vget_lane_u64(sum, 0);
@@ -689,7 +685,7 @@
uint16_t *src, int sstride,
int h) {
uint64x2_t square_result = vdupq_n_u64(0);
- int i = 0;
+ int i = h;
do {
// d7 d6 d5 d4 d3 d2 d1 d0 - 8 bit
const uint16x8_t dst_16x8 = vmovl_u8(vld1_u8(&dst[i * dstride]));
@@ -697,8 +693,7 @@
const uint16x8_t src_16x8 = vld1q_u16(&src[i * sstride]);
COMPUTE_MSE_16BIT(src_16x8, dst_16x8)
- i++;
- } while (i < h);
+ } while (--i != 0);
uint64x1_t sum =
vadd_u64(vget_high_u64(square_result), vget_low_u64(square_result));
return vget_lane_u64(sum, 0);