Add Neon high bit-depth SSE case

Add aom_highbd_sse_neon for the high bit-depth path, optimize the
final SSE addition in the existing Neon code, and incorporate the
suggested review changes.
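For reference, both entry points return the plain sum of squared
differences over a width x height block. A minimal scalar sketch of
the value the Neon kernels must reproduce (illustrative only, not the
library's C reference implementation); the high bit-depth variant
computes the same sum over uint16_t pixels:

  #include <stdint.h>

  // Sum of squared differences over a width x height block.
  static int64_t sse_scalar(const uint8_t *a, int a_stride,
                            const uint8_t *b, int b_stride,
                            int width, int height) {
    int64_t sse = 0;
    for (int y = 0; y < height; y++) {
      for (int x = 0; x < width; x++) {
        const int diff = a[x] - b[x];
        sse += (int64_t)diff * diff;  // accumulate (a[x] - b[x])^2
      }
      a += a_stride;
      b += b_stride;
    }
    return sse;
  }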
Change-Id: I8984e562873ac2ce39f3bfd6b9830b6423124116
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 382347f..a6c09e2 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -608,7 +608,7 @@
specialize qw/aom_highbd_subtract_block sse2/;
add_proto qw/int64_t/, "aom_highbd_sse", "const uint8_t *a8, int a_stride, const uint8_t *b8,int b_stride, int width, int height";
- specialize qw/aom_highbd_sse sse4_1 avx2/;
+ specialize qw/aom_highbd_sse sse4_1 avx2 neon/;
}
if (aom_config("CONFIG_AV1_ENCODER") eq "yes") {
diff --git a/aom_dsp/arm/sse_neon.c b/aom_dsp/arm/sse_neon.c
index 6f61b91..06b81cc 100644
--- a/aom_dsp/arm/sse_neon.c
+++ b/aom_dsp/arm/sse_neon.c
@@ -14,50 +14,474 @@
#include "aom/aom_integer.h"
+static INLINE uint32_t sse_W16x1_neon(uint8x16_t q2, uint8x16_t q3) {
+ const uint16_t sse1 = 0;
+ const uint16x8_t q1 = vld1q_dup_u16(&sse1);
+
+ uint32_t sse;
+
+ uint8x16_t q4 = vabdq_u8(q2, q3); // diff = abs(a[x] - b[x])
+ uint8x8_t d0 = vget_low_u8(q4);
+ uint8x8_t d1 = vget_high_u8(q4);
+
+ uint16x8_t q6 = vmlal_u8(q1, d0, d0);
+ uint16x8_t q7 = vmlal_u8(q1, d1, d1);
+
+ uint32x4_t q8 = vaddl_u16(vget_low_u16(q6), vget_high_u16(q6));
+ uint32x4_t q9 = vaddl_u16(vget_low_u16(q7), vget_high_u16(q7));
+
+ uint32x2_t d4 = vadd_u32(vget_low_u32(q8), vget_high_u32(q8));
+ uint32x2_t d5 = vadd_u32(vget_low_u32(q9), vget_high_u32(q9));
+
+ uint32x2_t d6 = vadd_u32(d4, d5);
+
+ sse = vget_lane_u32(d6, 0);
+ sse += vget_lane_u32(d6, 1);
+
+ return sse;
+}
+
int64_t aom_sse_neon(const uint8_t *a, int a_stride, const uint8_t *b,
int b_stride, int width, int height) {
- int addinc;
- uint8x8_t d0, d1;
+ const uint8x16_t q0 = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+ };
+ int addinc, x, y;
+ uint8x8_t d0, d1, d2, d3;
uint8_t dx;
- uint32x2_t d2, d3;
- uint8x16_t q0 = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
- uint32x4_t q8, q9;
- uint16x8_t q1, q6, q7;
uint8x16_t q2, q3, q4, q5;
uint32_t sse = 0;
- const uint16_t sse1 = 0;
- q1 = vld1q_dup_u16(&sse1);
- for (int y = 0; y < height; y++) {
- int x = width;
- while (x > 0) {
- addinc = width - x;
- q2 = vld1q_u8(a + addinc);
- q3 = vld1q_u8(b + addinc);
- if (x < 16) {
- dx = x;
- q4 = vld1q_dup_u8(&dx);
- q5 = vcltq_u8(q0, q4);
- q2 = vandq_u8(q2, q5);
- q3 = vandq_u8(q3, q5);
- }
- q4 = vabdq_u8(q2, q3); // diff = abs(a[x] - b[x])
- d0 = vget_low_u8(q4);
- d1 = vget_high_u8(q4);
- q6 = vmlal_u8(q1, d0, d0);
- q7 = vmlal_u8(q1, d1, d1);
- q8 = vaddl_u16(vget_low_u16(q6), vget_high_u16(q6));
- q9 = vaddl_u16(vget_low_u16(q7), vget_high_u16(q7));
+ uint8x8x2_t tmp, tmp2;
- d2 = vadd_u32(vget_low_u32(q8), vget_high_u32(q8));
- d3 = vadd_u32(vget_low_u32(q9), vget_high_u32(q9));
- sse += vget_lane_u32(d2, 0);
- sse += vget_lane_u32(d2, 1);
- sse += vget_lane_u32(d3, 0);
- sse += vget_lane_u32(d3, 1);
- x -= 16;
- }
- a += a_stride;
- b += b_stride;
+ switch (width) {
+ case 4:
+ for (y = 0; y < height; y += 4) {
+ d0 = vld1_u8(a); // load 4 pixels (low half of an 8-byte load)
+ a += a_stride;
+ d1 = vld1_u8(a);
+ a += a_stride;
+ d2 = vld1_u8(a);
+ a += a_stride;
+ d3 = vld1_u8(a);
+ a += a_stride;
+ tmp = vzip_u8(d0, d1);
+ tmp2 = vzip_u8(d2, d3);
+ q2 = vcombine_u8(tmp.val[0], tmp2.val[0]); // 16 pixels: 4 from each of the 4 rows
+
+ d0 = vld1_u8(b);
+ b += b_stride;
+ d1 = vld1_u8(b);
+ b += b_stride;
+ d2 = vld1_u8(b);
+ b += b_stride;
+ d3 = vld1_u8(b);
+ b += b_stride;
+ tmp = vzip_u8(d0, d1);
+ tmp2 = vzip_u8(d2, d3);
+ q3 = vcombine_u8(tmp.val[0], tmp2.val[0]);
+
+ sse += sse_W16x1_neon(q2, q3);
+ }
+ break;
+ case 8:
+ for (y = 0; y < height; y += 2) {
+ d0 = vld1_u8(a); // load one row of 8 pixels
+ d1 = vld1_u8(a + a_stride);
+ q2 = vcombine_u8(d0, d1); // 16 pixels: two rows of 8
+
+ d0 = vld1_u8(b);
+ d1 = vld1_u8(b + b_stride);
+ q3 = vcombine_u8(d0, d1);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ a += 2 * a_stride;
+ b += 2 * b_stride;
+ }
+ break;
+ case 16:
+ for (y = 0; y < height; y++) {
+ q2 = vld1q_u8(a);
+ q3 = vld1q_u8(b);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ a += a_stride;
+ b += b_stride;
+ }
+ break;
+ case 32:
+ for (y = 0; y < height; y++) {
+ q2 = vld1q_u8(a);
+ q3 = vld1q_u8(b);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 16);
+ q3 = vld1q_u8(b + 16);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ a += a_stride;
+ b += b_stride;
+ }
+ break;
+ case 64:
+ for (y = 0; y < height; y++) {
+ q2 = vld1q_u8(a);
+ q3 = vld1q_u8(b);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 16);
+ q3 = vld1q_u8(b + 16);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 32);
+ q3 = vld1q_u8(b + 32);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 48);
+ q3 = vld1q_u8(b + 48);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ a += a_stride;
+ b += b_stride;
+ }
+ break;
+ case 128:
+ for (y = 0; y < height; y++) {
+ q2 = vld1q_u8(a);
+ q3 = vld1q_u8(b);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 16);
+ q3 = vld1q_u8(b + 16);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 32);
+ q3 = vld1q_u8(b + 32);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 48);
+ q3 = vld1q_u8(b + 48);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 64);
+ q3 = vld1q_u8(b + 64);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 80);
+ q3 = vld1q_u8(b + 80);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 96);
+ q3 = vld1q_u8(b + 96);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ q2 = vld1q_u8(a + 112);
+ q3 = vld1q_u8(b + 112);
+
+ sse += sse_W16x1_neon(q2, q3);
+
+ a += a_stride;
+ b += b_stride;
+ }
+ break;
+ default:
+ for (y = 0; y < height; y++) {
+ x = width;
+ while (x > 0) {
+ addinc = width - x;
+ q2 = vld1q_u8(a + addinc);
+ q3 = vld1q_u8(b + addinc);
+ if (x < 16) {
+ dx = x;
+ q4 = vld1q_dup_u8(&dx);
+ q5 = vcltq_u8(q0, q4);
+ q2 = vandq_u8(q2, q5);
+ q3 = vandq_u8(q3, q5);
+ }
+ sse += sse_W16x1_neon(q2, q3);
+ x -= 16;
+ }
+ a += a_stride;
+ b += b_stride;
+ }
}
return (int64_t)sse;
}
+
+#if CONFIG_AV1_HIGHBITDEPTH
+static INLINE uint32_t highbd_sse_W8x1_neon(uint16x8_t q2, uint16x8_t q3) {
+ uint32_t sse;
+ const uint32_t sse1 = 0;
+ const uint32x4_t q1 = vld1q_dup_u32(&sse1);
+
+ uint16x8_t q4 = vabdq_u16(q2, q3); // diff = abs(a[x] - b[x])
+ uint16x4_t d0 = vget_low_u16(q4);
+ uint16x4_t d1 = vget_high_u16(q4);
+
+ uint32x4_t q6 = vmlal_u16(q1, d0, d0);
+ uint32x4_t q7 = vmlal_u16(q1, d1, d1);
+
+ uint32x2_t d4 = vadd_u32(vget_low_u32(q6), vget_high_u32(q6));
+ uint32x2_t d5 = vadd_u32(vget_low_u32(q7), vget_high_u32(q7));
+
+ uint32x2_t d6 = vadd_u32(d4, d5);
+
+ sse = vget_lane_u32(d6, 0);
+ sse += vget_lane_u32(d6, 1);
+
+ return sse;
+}
+
+int64_t aom_highbd_sse_neon(const uint8_t *a8, int a_stride, const uint8_t *b8,
+ int b_stride, int width, int height) {
+ const uint16x8_t q0 = { 0, 1, 2, 3, 4, 5, 6, 7 };
+ int64_t sse = 0;
+ uint16_t *a = CONVERT_TO_SHORTPTR(a8);
+ uint16_t *b = CONVERT_TO_SHORTPTR(b8);
+ int x, y;
+ int addinc;
+ uint16x4_t d0, d1, d2, d3;
+ uint16_t dx;
+ uint16x8_t q2, q3, q4, q5;
+
+ switch (width) {
+ case 4:
+ for (y = 0; y < height; y += 2) {
+ d0 = vld1_u16(a); // load one row of 4 pixels
+ a += a_stride;
+ d1 = vld1_u16(a);
+ a += a_stride;
+
+ d2 = vld1_u16(b);
+ b += b_stride;
+ d3 = vld1_u16(b);
+ b += b_stride;
+ q2 = vcombine_u16(d0, d1); // 8 pixels: two rows of 4
+ q3 = vcombine_u16(d2, d3);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+ }
+ break;
+ case 8:
+ for (y = 0; y < height; y++) {
+ q2 = vld1q_u16(a);
+ q3 = vld1q_u16(b);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ a += a_stride;
+ b += b_stride;
+ }
+ break;
+ case 16:
+ for (y = 0; y < height; y++) {
+ q2 = vld1q_u16(a);
+ q3 = vld1q_u16(b);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 8);
+ q3 = vld1q_u16(b + 8);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ a += a_stride;
+ b += b_stride;
+ }
+ break;
+ case 32:
+ for (y = 0; y < height; y++) {
+ q2 = vld1q_u16(a);
+ q3 = vld1q_u16(b);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 8);
+ q3 = vld1q_u16(b + 8);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 16);
+ q3 = vld1q_u16(b + 16);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 24);
+ q3 = vld1q_u16(b + 24);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ a += a_stride;
+ b += b_stride;
+ }
+ break;
+ case 64:
+ for (y = 0; y < height; y++) {
+ q2 = vld1q_u16(a);
+ q3 = vld1q_u16(b);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 8);
+ q3 = vld1q_u16(b + 8);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 16);
+ q3 = vld1q_u16(b + 16);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 24);
+ q3 = vld1q_u16(b + 24);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 32);
+ q3 = vld1q_u16(b + 32);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 40);
+ q3 = vld1q_u16(b + 40);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 48);
+ q3 = vld1q_u16(b + 48);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 56);
+ q3 = vld1q_u16(b + 56);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ a += a_stride;
+ b += b_stride;
+ }
+ break;
+ case 128:
+ for (y = 0; y < height; y++) {
+ q2 = vld1q_u16(a);
+ q3 = vld1q_u16(b);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 8);
+ q3 = vld1q_u16(b + 8);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 16);
+ q3 = vld1q_u16(b + 16);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 24);
+ q3 = vld1q_u16(b + 24);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 32);
+ q3 = vld1q_u16(b + 32);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 40);
+ q3 = vld1q_u16(b + 40);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 48);
+ q3 = vld1q_u16(b + 48);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 56);
+ q3 = vld1q_u16(b + 56);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 64);
+ q3 = vld1q_u16(b + 64);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 72);
+ q3 = vld1q_u16(b + 72);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 80);
+ q3 = vld1q_u16(b + 80);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 88);
+ q3 = vld1q_u16(b + 88);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 96);
+ q3 = vld1q_u16(b + 96);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 104);
+ q3 = vld1q_u16(b + 104);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 112);
+ q3 = vld1q_u16(b + 112);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+
+ q2 = vld1q_u16(a + 120);
+ q3 = vld1q_u16(b + 120);
+
+ sse += highbd_sse_W8x1_neon(q2, q3);
+ a += a_stride;
+ b += b_stride;
+ }
+ break;
+ default:
+ for (y = 0; y < height; y++) {
+ x = width;
+ while (x > 0) {
+ addinc = width - x;
+ q2 = vld1q_u16(a + addinc);
+ q3 = vld1q_u16(b + addinc);
+ if (x < 8) {
+ dx = x;
+ q4 = vld1q_dup_u16(&dx);
+ q5 = vcltq_u16(q0, q4);
+ q2 = vandq_u16(q2, q5);
+ q3 = vandq_u16(q3, q5);
+ }
+ sse += highbd_sse_W8x1_neon(q2, q3);
+ x -= 8;
+ }
+ a += a_stride;
+ b += b_stride;
+ }
+ }
+ return (int64_t)sse;
+}
+#endif
diff --git a/test/sum_squares_test.cc b/test/sum_squares_test.cc
index 4644e71..8845466 100644
--- a/test/sum_squares_test.cc
+++ b/test/sum_squares_test.cc
@@ -389,6 +389,9 @@
#if HAVE_NEON
TestSSEFuncs sse_neon[] = {
TestSSEFuncs(&aom_sse_c, &aom_sse_neon),
+#if CONFIG_AV1_HIGHBITDEPTH
+ TestSSEFuncs(&aom_highbd_sse_c, &aom_highbd_sse_neon)
+#endif
};
INSTANTIATE_TEST_SUITE_P(NEON, SSETest,
Combine(ValuesIn(sse_neon), Range(4, 129, 4)));
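
Usage note: the high bit-depth entry point takes uint8_t * handles even
though the pixels are 16-bit; callers wrap their uint16_t buffers with
CONVERT_TO_BYTEPTR so the kernel can recover the real pointer via
CONVERT_TO_SHORTPTR, and strides are given in pixels. A sketch of a
hypothetical caller (the helper name, buffer names, and include layout
are illustrative, not part of this change):

  #include "aom_dsp/aom_dsp_common.h"  // CONVERT_TO_BYTEPTR
  #include "config/aom_dsp_rtcd.h"     // aom_highbd_sse prototype

  // Hypothetical helper: high bit-depth SSE for a block whose pixels
  // live in ordinary uint16_t buffers; strides are in pixels.
  static int64_t block_sse_hbd(const uint16_t *src, int src_stride,
                               const uint16_t *ref, int ref_stride,
                               int width, int height) {
    return aom_highbd_sse(CONVERT_TO_BYTEPTR(src), src_stride,
                          CONVERT_TO_BYTEPTR(ref), ref_stride,
                          width, height);
  }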