Optimize Neon HBD sad_avg functions
Optimize the aom_highbd_sad_avg_neon functions by accumulating absolute
differences into 16-bit vectors and widening to 32-bit accumulators only
at the point where the 16-bit sums could otherwise overflow.
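
In outline, the new structure looks like the following minimal sketch
(illustrative only, not the patched libaom code: the function name is
hypothetical, 12-bit input is assumed so each absolute difference is at
most 4095, and vaddvq_u32 stands in for libaom's horizontal_add_u32x4):

  #include <arm_neon.h>
  #include <assert.h>
  #include <stdint.h>

  static uint32_t sad8xh_avg_sketch(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride,
                                    const uint16_t *pred, int h) {
    // A 16-bit lane holds at most 65535 / 4095 = 16 absolute differences.
    const int h_limit = h < 16 ? h : 16;
    assert(h % h_limit == 0);
    uint32x4_t sum_u32 = vdupq_n_u32(0);
    do {
      // Inner loop: accumulate abs diffs in 16 bits via vabaq_u16.
      uint16x8_t sum_u16 = vdupq_n_u16(0);
      int i = h_limit;
      do {
        uint16x8_t s = vld1q_u16(src);
        uint16x8_t avg = vrhaddq_u16(vld1q_u16(ref), vld1q_u16(pred));
        sum_u16 = vabaq_u16(sum_u16, s, avg);
        src += src_stride;
        ref += ref_stride;
        pred += 8;
      } while (--i != 0);
      // Widen to 32-bit elements only once per h_limit rows.
      sum_u32 = vpadalq_u16(sum_u32, sum_u16);
      h -= h_limit;
    } while (h != 0);
    return vaddvq_u32(sum_u32);  // AArch64 horizontal add
  }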
Change-Id: I7ec960ca1912233f813515718c253d3bc3f0f79d
diff --git a/aom_dsp/arm/highbd_sad_neon.c b/aom_dsp/arm/highbd_sad_neon.c
index 0487a32..cb732cb 100644
--- a/aom_dsp/arm/highbd_sad_neon.c
+++ b/aom_dsp/arm/highbd_sad_neon.c
@@ -256,24 +256,40 @@
const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred);
- uint32x4_t sum = vdupq_n_u32(0);
- int i = h;
+ // 'h_overflow' is the number of 8-wide rows we can process before 16-bit
+ // accumulators overflow. After hitting this limit, accumulate into 32-bit
+ // elements. 65535 / 4095 ~= 16, so 16 8-wide rows.
+ const int h_overflow = 16;
+ // If block height 'h' is smaller than this limit, use 'h' instead.
+ const int h_limit = h < h_overflow ? h : h_overflow;
+ assert(h % h_limit == 0);
+
+ uint32x4_t sum_u32 = vdupq_n_u32(0);
+
do {
- uint16x8_t s = vld1q_u16(src16_ptr);
- uint16x8_t r = vld1q_u16(ref16_ptr);
- uint16x8_t p = vld1q_u16(pred16_ptr);
+ uint16x8_t sum_u16 = vdupq_n_u16(0);
- uint16x8_t avg = vrhaddq_u16(r, p);
- uint16x8_t diff = vabdq_u16(s, avg);
- sum = vpadalq_u16(sum, diff);
+ int i = h_limit;
+ do {
+ uint16x8_t s = vld1q_u16(src16_ptr);
+ uint16x8_t r = vld1q_u16(ref16_ptr);
+ uint16x8_t p = vld1q_u16(pred16_ptr);
- src16_ptr += src_stride;
- ref16_ptr += ref_stride;
- pred16_ptr += 8;
- } while (--i != 0);
+ uint16x8_t avg = vrhaddq_u16(r, p);
+ sum_u16 = vabaq_u16(sum_u16, s, avg);
- return horizontal_add_u32x4(sum);
+ src16_ptr += src_stride;
+ ref16_ptr += ref_stride;
+ pred16_ptr += 8;
+ } while (--i != 0);
+
+ sum_u32 = vpadalq_u16(sum_u32, sum_u16);
+
+ h -= h_limit;
+ } while (h != 0);
+
+ return horizontal_add_u32x4(sum_u32);
}
static inline uint32_t highbd_sad16xh_avg_neon(const uint8_t *src_ptr,
@@ -284,120 +300,157 @@
const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred);
- uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
- int i = h;
+ // 'h_overflow' is the number of 16-wide rows we can process before 16-bit
+ // accumulators overflow. After hitting this limit, accumulate into 32-bit
+ // elements. 65535 / 4095 ~= 16, so 16 16-wide rows using two accumulators.
+ const int h_overflow = 16;
+ // If block height 'h' is smaller than this limit, use 'h' instead.
+ const int h_limit = h < h_overflow ? h : h_overflow;
+ assert(h % h_limit == 0);
+
+ uint32x4_t sum_u32 = vdupq_n_u32(0);
+
do {
- uint16x8_t s0, s1, r0, r1, p0, p1;
- uint16x8_t avg0, avg1, diff0, diff1;
+ uint16x8_t sum_u16[2] = { vdupq_n_u16(0), vdupq_n_u16(0) };
- s0 = vld1q_u16(src16_ptr);
- r0 = vld1q_u16(ref16_ptr);
- p0 = vld1q_u16(pred16_ptr);
- avg0 = vrhaddq_u16(r0, p0);
- diff0 = vabdq_u16(s0, avg0);
- sum[0] = vpadalq_u16(sum[0], diff0);
+ int i = h_limit;
+ do {
+ uint16x8_t s0 = vld1q_u16(src16_ptr);
+ uint16x8_t r0 = vld1q_u16(ref16_ptr);
+ uint16x8_t p0 = vld1q_u16(pred16_ptr);
- s1 = vld1q_u16(src16_ptr + 8);
- r1 = vld1q_u16(ref16_ptr + 8);
- p1 = vld1q_u16(pred16_ptr + 8);
- avg1 = vrhaddq_u16(r1, p1);
- diff1 = vabdq_u16(s1, avg1);
- sum[1] = vpadalq_u16(sum[1], diff1);
+ uint16x8_t avg0 = vrhaddq_u16(r0, p0);
+ sum_u16[0] = vabaq_u16(sum_u16[0], s0, avg0);
- src16_ptr += src_stride;
- ref16_ptr += ref_stride;
- pred16_ptr += 16;
- } while (--i != 0);
+ uint16x8_t s1 = vld1q_u16(src16_ptr + 8);
+ uint16x8_t r1 = vld1q_u16(ref16_ptr + 8);
+ uint16x8_t p1 = vld1q_u16(pred16_ptr + 8);
- sum[0] = vaddq_u32(sum[0], sum[1]);
- return horizontal_add_u32x4(sum[0]);
+ uint16x8_t avg1 = vrhaddq_u16(r1, p1);
+ sum_u16[1] = vabaq_u16(sum_u16[1], s1, avg1);
+
+ src16_ptr += src_stride;
+ ref16_ptr += ref_stride;
+ pred16_ptr += 16;
+ } while (--i != 0);
+
+ sum_u32 = vpadalq_u16(sum_u32, sum_u16[0]);
+ sum_u32 = vpadalq_u16(sum_u32, sum_u16[1]);
+
+ h -= h_limit;
+ } while (h != 0);
+
+ return horizontal_add_u32x4(sum_u32);
}
static inline uint32_t highbd_sadwxh_avg_neon(const uint8_t *src_ptr,
int src_stride,
const uint8_t *ref_ptr,
- int ref_stride, int w, int h,
- const uint8_t *second_pred) {
+ int ref_stride,
+ const uint8_t *second_pred, int w,
+ int h, const int h_overflow) {
const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred);
- uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
- vdupq_n_u32(0) };
- int i = h;
+ const int h_limit = h < h_overflow ? h : h_overflow;
+ assert(h % h_limit == 0);
+
+ uint32x4_t sum_u32 = vdupq_n_u32(0);
+
do {
- int j = 0;
+ uint16x8_t sum_u16[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+ vdupq_n_u16(0) };
+
+ int i = h_limit;
do {
- uint16x8_t s0, s1, s2, s3, r0, r1, r2, r3, p0, p1, p2, p3;
- uint16x8_t avg0, avg1, avg2, avg3, diff0, diff1, diff2, diff3;
+ int j = 0;
+ do {
+ uint16x8_t s0 = vld1q_u16(src16_ptr + j);
+ uint16x8_t r0 = vld1q_u16(ref16_ptr + j);
+ uint16x8_t p0 = vld1q_u16(pred16_ptr + j);
- s0 = vld1q_u16(src16_ptr + j);
- r0 = vld1q_u16(ref16_ptr + j);
- p0 = vld1q_u16(pred16_ptr + j);
- avg0 = vrhaddq_u16(r0, p0);
- diff0 = vabdq_u16(s0, avg0);
- sum[0] = vpadalq_u16(sum[0], diff0);
+ uint16x8_t avg0 = vrhaddq_u16(r0, p0);
+ sum_u16[0] = vabaq_u16(sum_u16[0], s0, avg0);
- s1 = vld1q_u16(src16_ptr + j + 8);
- r1 = vld1q_u16(ref16_ptr + j + 8);
- p1 = vld1q_u16(pred16_ptr + j + 8);
- avg1 = vrhaddq_u16(r1, p1);
- diff1 = vabdq_u16(s1, avg1);
- sum[1] = vpadalq_u16(sum[1], diff1);
+ uint16x8_t s1 = vld1q_u16(src16_ptr + j + 8);
+ uint16x8_t r1 = vld1q_u16(ref16_ptr + j + 8);
+ uint16x8_t p1 = vld1q_u16(pred16_ptr + j + 8);
- s2 = vld1q_u16(src16_ptr + j + 16);
- r2 = vld1q_u16(ref16_ptr + j + 16);
- p2 = vld1q_u16(pred16_ptr + j + 16);
- avg2 = vrhaddq_u16(r2, p2);
- diff2 = vabdq_u16(s2, avg2);
- sum[2] = vpadalq_u16(sum[2], diff2);
+ uint16x8_t avg1 = vrhaddq_u16(r1, p1);
+ sum_u16[1] = vabaq_u16(sum_u16[1], s1, avg1);
- s3 = vld1q_u16(src16_ptr + j + 24);
- r3 = vld1q_u16(ref16_ptr + j + 24);
- p3 = vld1q_u16(pred16_ptr + j + 24);
- avg3 = vrhaddq_u16(r3, p3);
- diff3 = vabdq_u16(s3, avg3);
- sum[3] = vpadalq_u16(sum[3], diff3);
+ uint16x8_t s2 = vld1q_u16(src16_ptr + j + 16);
+ uint16x8_t r2 = vld1q_u16(ref16_ptr + j + 16);
+ uint16x8_t p2 = vld1q_u16(pred16_ptr + j + 16);
- j += 32;
- } while (j < w);
+ uint16x8_t avg2 = vrhaddq_u16(r2, p2);
+ sum_u16[2] = vabaq_u16(sum_u16[2], s2, avg2);
- src16_ptr += src_stride;
- ref16_ptr += ref_stride;
- pred16_ptr += w;
- } while (--i != 0);
+ uint16x8_t s3 = vld1q_u16(src16_ptr + j + 24);
+ uint16x8_t r3 = vld1q_u16(ref16_ptr + j + 24);
+ uint16x8_t p3 = vld1q_u16(pred16_ptr + j + 24);
- sum[0] = vaddq_u32(sum[0], sum[1]);
- sum[2] = vaddq_u32(sum[2], sum[3]);
- sum[0] = vaddq_u32(sum[0], sum[2]);
+ uint16x8_t avg3 = vrhaddq_u16(r3, p3);
+ sum_u16[3] = vabaq_u16(sum_u16[3], s3, avg3);
- return horizontal_add_u32x4(sum[0]);
+ j += 32;
+ } while (j < w);
+
+ src16_ptr += src_stride;
+ ref16_ptr += ref_stride;
+ pred16_ptr += w;
+ } while (--i != 0);
+
+ sum_u32 = vpadalq_u16(sum_u32, sum_u16[0]);
+ sum_u32 = vpadalq_u16(sum_u32, sum_u16[1]);
+ sum_u32 = vpadalq_u16(sum_u32, sum_u16[2]);
+ sum_u32 = vpadalq_u16(sum_u32, sum_u16[3]);
+
+ h -= h_limit;
+ } while (h != 0);
+
+ return horizontal_add_u32x4(sum_u32);
}
-static inline unsigned int highbd_sad128xh_avg_neon(
- const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
- int ref_stride, int h, const uint8_t *second_pred) {
- return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride, 128,
- h, second_pred);
+static inline uint32_t highbd_sad32xh_avg_neon(const uint8_t *src_ptr,
+ int src_stride,
+ const uint8_t *ref_ptr,
+ int ref_stride, int h,
+ const uint8_t *second_pred) {
+ // 'h_overflow' is the number of 32-wide rows we can process before 16-bit
+ // accumulators overflow. After hitting this limit, accumulate into 32-bit
+ // elements. 65535 / 4095 ~= 16, so 16 32-wide rows using four accumulators.
+ const int h_overflow = 16;
+ return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride,
+ second_pred, 32, h, h_overflow);
}
-static inline unsigned int highbd_sad64xh_avg_neon(const uint8_t *src_ptr,
- int src_stride,
- const uint8_t *ref_ptr,
- int ref_stride, int h,
- const uint8_t *second_pred) {
- return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride, 64, h,
- second_pred);
+static inline uint32_t highbd_sad64xh_avg_neon(const uint8_t *src_ptr,
+ int src_stride,
+ const uint8_t *ref_ptr,
+ int ref_stride, int h,
+ const uint8_t *second_pred) {
+ // 'h_overflow' is the number of 64-wide rows we can process before 16-bit
+ // accumulators overflow. After hitting this limit, accumulate into 32-bit
+ // elements. 65535 / 4095 ~= 16, so 8 64-wide rows using four accumulators.
+ const int h_overflow = 8;
+ return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride,
+ second_pred, 64, h, h_overflow);
}
-static inline unsigned int highbd_sad32xh_avg_neon(const uint8_t *src_ptr,
- int src_stride,
- const uint8_t *ref_ptr,
- int ref_stride, int h,
- const uint8_t *second_pred) {
- return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride, 32, h,
- second_pred);
+static inline uint32_t highbd_sad128xh_avg_neon(const uint8_t *src_ptr,
+ int src_stride,
+ const uint8_t *ref_ptr,
+ int ref_stride, int h,
+ const uint8_t *second_pred) {
+ // 'h_overflow' is the number of 128-wide rows we can process before 16-bit
+ // accumulators overflow. After hitting this limit, accumulate into 32-bit
+ // elements. 65535 / 4095 ~= 16, so 4 128-wide rows using four accumulators.
+ const int h_overflow = 4;
+ return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride,
+ second_pred, 128, h, h_overflow);
}
#define HBD_SAD_WXH_AVG_NEON(w, h) \
@@ -435,3 +488,5 @@
HBD_SAD_WXH_AVG_NEON(64, 16)
#endif // !CONFIG_REALTIME_ONLY
+
+#undef HBD_SAD_WXH_AVG_NEON
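
For reference, the h_overflow values chosen above all follow from the same
bound. A hypothetical standalone check of the arithmetic (assuming 12-bit
samples and the four-accumulator wide path of highbd_sadwxh_avg_neon):

  #include <stdio.h>

  // A 16-bit lane absorbs at most 65535 / 4095 = 16 absolute differences.
  // The wide path keeps four uint16x8_t accumulators (4 * 8 = 32 lanes),
  // so a w-wide row deposits w / 32 differences into each lane:
  //   h_overflow = 16 / (w / 32)
  int main(void) {
    for (int w = 32; w <= 128; w *= 2) {
      printf("w = %3d: h_overflow = %2d\n", w, 16 / (w / 32));
    }
    return 0;  // prints 16, 8 and 4, matching the values in the patch
  }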