Add pixel_proj_error simd code
1. add low bitdepth pixel_proj_error sse4_1 code
2. add low bitdepth pixel_proj_error avx2 code
3. add low bitdepth pixel_proj_error unittest
Speed up about 0.6% without rd change
test sequence: BasketballDrill_832x480_50.y4m
test command line:./aomenc --cpu-used=1 --psnr -D \
-q --end-usage=vbr --target-bitrate=800 --limit=20 \
BasketballDrill_832x480_50.y4m -otest.webm
Change-Id: Ia50e6c7ff70da24e5564cc52069e54b585d7c8c1
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 0bff779..f149ab0 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -253,6 +253,9 @@
add_proto qw/void av1_compute_stats/, "int wiener_win, const uint8_t *dgd8, const uint8_t *src8, int h_start, int h_end, int v_start, int v_end, int dgd_stride, int src_stride, double *M, double *H";
specialize qw/av1_compute_stats sse4_1 avx2/;
+
+ add_proto qw/int64_t av1_lowbd_pixel_proj_error/, " const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params";
+ specialize qw/av1_lowbd_pixel_proj_error sse4_1 avx2/;
}
# end encoder functions
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 86bfe04..f6a1bae 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -181,6 +181,77 @@
return sse_restoration_unit(limits, rsc->src, rsc->dst, plane, highbd);
}
+int64_t av1_lowbd_pixel_proj_error_c(const uint8_t *src8, int width, int height,
+ int src_stride, const uint8_t *dat8,
+ int dat_stride, int32_t *flt0,
+ int flt0_stride, int32_t *flt1,
+ int flt1_stride, int xq[2],
+ const sgr_params_type *params) {
+ int i, j;
+ const uint8_t *src = src8;
+ const uint8_t *dat = dat8;
+ int64_t err = 0;
+ if (params->r[0] > 0 && params->r[1] > 0) {
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j < width; ++j) {
+ assert(flt1[j] < (1 << 15) && flt1[j] > -(1 << 15));
+ assert(flt0[j] < (1 << 15) && flt0[j] > -(1 << 15));
+ const int32_t u = (int32_t)(dat[j] << SGRPROJ_RST_BITS);
+ int32_t v = u << SGRPROJ_PRJ_BITS;
+ v += xq[0] * (flt0[j] - u) + xq[1] * (flt1[j] - u);
+ const int32_t e =
+ ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) - src[j];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ flt1 += flt1_stride;
+ }
+ } else if (params->r[0] > 0) {
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j < width; ++j) {
+ assert(flt0[j] < (1 << 15) && flt0[j] > -(1 << 15));
+ const int32_t u = (int32_t)(dat[j] << SGRPROJ_RST_BITS);
+ int32_t v = u << SGRPROJ_PRJ_BITS;
+ v += xq[0] * (flt0[j] - u);
+ const int32_t e =
+ ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) - src[j];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ }
+ } else if (params->r[1] > 0) {
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j < width; ++j) {
+ assert(flt1[j] < (1 << 15) && flt1[j] > -(1 << 15));
+ const int32_t u = (int32_t)(dat[j] << SGRPROJ_RST_BITS);
+ int32_t v = u << SGRPROJ_PRJ_BITS;
+ v += xq[1] * (flt1[j] - u);
+ const int32_t e =
+ ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) - src[j];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt1 += flt1_stride;
+ }
+ } else {
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j < width; ++j) {
+ const int32_t e = (int32_t)(dat[j]) - src[j];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ }
+ }
+
+ return err;
+}
+
static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
int src_stride, const uint8_t *dat8,
int dat_stride, int use_highbitdepth,
@@ -192,21 +263,9 @@
int xq[2];
decode_xq(xqd, xq, params);
if (!use_highbitdepth) {
- const uint8_t *src = src8;
- const uint8_t *dat = dat8;
- for (i = 0; i < height; ++i) {
- for (j = 0; j < width; ++j) {
- const int32_t u =
- (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
- int32_t v = u << SGRPROJ_PRJ_BITS;
- if (params->r[0] > 0) v += xq[0] * (flt0[i * flt0_stride + j] - u);
- if (params->r[1] > 0) v += xq[1] * (flt1[i * flt1_stride + j] - u);
- const int32_t e =
- ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
- src[i * src_stride + j];
- err += e * e;
- }
- }
+ err = av1_lowbd_pixel_proj_error(src8, width, height, src_stride, dat8,
+ dat_stride, flt0, flt0_stride, flt1,
+ flt1_stride, xq, params);
} else {
const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
diff --git a/av1/encoder/x86/pickrst_avx2.c b/av1/encoder/x86/pickrst_avx2.c
index 31a063b..548f6c3 100644
--- a/av1/encoder/x86/pickrst_avx2.c
+++ b/av1/encoder/x86/pickrst_avx2.c
@@ -227,3 +227,176 @@
dgd_stride, src_stride, M, H);
}
}
+
+static INLINE __m256i pair_set_epi16(uint16_t a, uint16_t b) {
+ return _mm256_set1_epi32(
+ (int32_t)(((uint16_t)(a)) | (((uint32_t)(b)) << 16)));
+}
+
+int64_t av1_lowbd_pixel_proj_error_avx2(
+ const uint8_t *src8, int width, int height, int src_stride,
+ const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
+ int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
+ int i, j, k;
+ const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS;
+ const __m256i rounding = _mm256_set1_epi32(1 << (shift - 1));
+ __m256i sum64 = _mm256_setzero_si256();
+ const uint8_t *src = src8;
+ const uint8_t *dat = dat8;
+ int64_t err = 0;
+ if (params->r[0] > 0 && params->r[1] > 0) {
+ __m256i xq_coeff = pair_set_epi16(xq[0], xq[1]);
+ for (i = 0; i < height; ++i) {
+ __m256i sum32 = _mm256_setzero_si256();
+ for (j = 0; j <= width - 16; j += 16) {
+ const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+ const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+ const __m256i flt0_16b = _mm256_permute4x64_epi64(
+ _mm256_packs_epi32(yy_loadu_256(flt0 + j),
+ yy_loadu_256(flt0 + j + 8)),
+ 0xd8);
+ const __m256i flt1_16b = _mm256_permute4x64_epi64(
+ _mm256_packs_epi32(yy_loadu_256(flt1 + j),
+ yy_loadu_256(flt1 + j + 8)),
+ 0xd8);
+ const __m256i u0 = _mm256_slli_epi16(d0, SGRPROJ_RST_BITS);
+ const __m256i flt0_0_sub_u = _mm256_sub_epi16(flt0_16b, u0);
+ const __m256i flt1_0_sub_u = _mm256_sub_epi16(flt1_16b, u0);
+ const __m256i v0 = _mm256_madd_epi16(
+ xq_coeff, _mm256_unpacklo_epi16(flt0_0_sub_u, flt1_0_sub_u));
+ const __m256i v1 = _mm256_madd_epi16(
+ xq_coeff, _mm256_unpackhi_epi16(flt0_0_sub_u, flt1_0_sub_u));
+ const __m256i vr0 =
+ _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift);
+ const __m256i vr1 =
+ _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift);
+ const __m256i e0 = _mm256_sub_epi16(
+ _mm256_add_epi16(_mm256_packs_epi32(vr0, vr1), d0), s0);
+ const __m256i err0 = _mm256_madd_epi16(e0, e0);
+ sum32 = _mm256_add_epi32(sum32, err0);
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ flt1 += flt1_stride;
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, sum64_0);
+ sum64 = _mm256_add_epi64(sum64, sum64_1);
+ }
+ } else if (params->r[0] > 0) {
+ __m256i xq_coeff = pair_set_epi16(xq[0], -(xq[0] << SGRPROJ_RST_BITS));
+ for (i = 0; i < height; ++i) {
+ __m256i sum32 = _mm256_setzero_si256();
+ for (j = 0; j <= width - 16; j += 16) {
+ const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+ const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+ const __m256i flt0_16b = _mm256_permute4x64_epi64(
+ _mm256_packs_epi32(yy_loadu_256(flt0 + j),
+ yy_loadu_256(flt0 + j + 8)),
+ 0xd8);
+ const __m256i v0 =
+ _mm256_madd_epi16(xq_coeff, _mm256_unpacklo_epi16(flt0_16b, d0));
+ const __m256i v1 =
+ _mm256_madd_epi16(xq_coeff, _mm256_unpackhi_epi16(flt0_16b, d0));
+ const __m256i vr0 =
+ _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift);
+ const __m256i vr1 =
+ _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift);
+ const __m256i e0 = _mm256_sub_epi16(
+ _mm256_add_epi16(_mm256_packs_epi32(vr0, vr1), d0), s0);
+ const __m256i err0 = _mm256_madd_epi16(e0, e0);
+ sum32 = _mm256_add_epi32(sum32, err0);
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq[0] * (flt0[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, sum64_0);
+ sum64 = _mm256_add_epi64(sum64, sum64_1);
+ }
+ } else if (params->r[1] > 0) {
+ __m256i xq_coeff = pair_set_epi16(xq[1], -(xq[1] << SGRPROJ_RST_BITS));
+ for (i = 0; i < height; ++i) {
+ __m256i sum32 = _mm256_setzero_si256();
+ for (j = 0; j <= width - 16; j += 16) {
+ const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+ const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+ const __m256i flt1_16b = _mm256_permute4x64_epi64(
+ _mm256_packs_epi32(yy_loadu_256(flt1 + j),
+ yy_loadu_256(flt1 + j + 8)),
+ 0xd8);
+ const __m256i v0 =
+ _mm256_madd_epi16(xq_coeff, _mm256_unpacklo_epi16(flt1_16b, d0));
+ const __m256i v1 =
+ _mm256_madd_epi16(xq_coeff, _mm256_unpackhi_epi16(flt1_16b, d0));
+ const __m256i vr0 =
+ _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift);
+ const __m256i vr1 =
+ _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift);
+ const __m256i e0 = _mm256_sub_epi16(
+ _mm256_add_epi16(_mm256_packs_epi32(vr0, vr1), d0), s0);
+ const __m256i err0 = _mm256_madd_epi16(e0, e0);
+ sum32 = _mm256_add_epi32(sum32, err0);
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq[1] * (flt1[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt1 += flt1_stride;
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, sum64_0);
+ sum64 = _mm256_add_epi64(sum64, sum64_1);
+ }
+ } else {
+ __m256i sum32 = _mm256_setzero_si256();
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j <= width - 16; j += 16) {
+ const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+ const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+ const __m256i diff0 = _mm256_sub_epi16(d0, s0);
+ const __m256i err0 = _mm256_madd_epi16(diff0, diff0);
+ sum32 = _mm256_add_epi32(sum32, err0);
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t e = (int32_t)(dat[k]) - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ }
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64_0, sum64_1);
+ }
+ int64_t sum[4];
+ yy_storeu_256(sum, sum64);
+ err += sum[0] + sum[1] + sum[2] + sum[3];
+ return err;
+}
diff --git a/av1/encoder/x86/pickrst_sse4.c b/av1/encoder/x86/pickrst_sse4.c
index 4f6cde0..04e4d1a 100644
--- a/av1/encoder/x86/pickrst_sse4.c
+++ b/av1/encoder/x86/pickrst_sse4.c
@@ -230,3 +230,160 @@
dgd_stride, src_stride, M, H);
}
}
+
+static INLINE __m128i pair_set_epi16(uint16_t a, uint16_t b) {
+ return _mm_set1_epi32((int32_t)(((uint16_t)(a)) | (((uint32_t)(b)) << 16)));
+}
+
+int64_t av1_lowbd_pixel_proj_error_sse4_1(
+ const uint8_t *src8, int width, int height, int src_stride,
+ const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
+ int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
+ int i, j, k;
+ const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS;
+ const __m128i rounding = _mm_set1_epi32(1 << (shift - 1));
+ __m128i sum64 = _mm_setzero_si128();
+ const uint8_t *src = src8;
+ const uint8_t *dat = dat8;
+ int64_t err = 0;
+ if (params->r[0] > 0 && params->r[1] > 0) {
+ __m128i xq_coeff = pair_set_epi16(xq[0], xq[1]);
+ for (i = 0; i < height; ++i) {
+ __m128i sum32 = _mm_setzero_si128();
+ for (j = 0; j < width - 8; j += 8) {
+ const __m128i d0 = _mm_cvtepu8_epi16(xx_loadl_64(dat + j));
+ const __m128i s0 = _mm_cvtepu8_epi16(xx_loadl_64(src + j));
+ const __m128i flt0_16b =
+ _mm_packs_epi32(xx_loadu_128(flt0 + j), xx_loadu_128(flt0 + j + 4));
+ const __m128i flt1_16b =
+ _mm_packs_epi32(xx_loadu_128(flt1 + j), xx_loadu_128(flt1 + j + 4));
+ const __m128i u0 = _mm_slli_epi16(d0, SGRPROJ_RST_BITS);
+ const __m128i flt0_0_sub_u = _mm_sub_epi16(flt0_16b, u0);
+ const __m128i flt1_0_sub_u = _mm_sub_epi16(flt1_16b, u0);
+ const __m128i v0 = _mm_madd_epi16(
+ xq_coeff, _mm_unpacklo_epi16(flt0_0_sub_u, flt1_0_sub_u));
+ const __m128i v1 = _mm_madd_epi16(
+ xq_coeff, _mm_unpackhi_epi16(flt0_0_sub_u, flt1_0_sub_u));
+ const __m128i vr0 = _mm_srai_epi32(_mm_add_epi32(v0, rounding), shift);
+ const __m128i vr1 = _mm_srai_epi32(_mm_add_epi32(v1, rounding), shift);
+ const __m128i e0 =
+ _mm_sub_epi16(_mm_add_epi16(_mm_packs_epi32(vr0, vr1), d0), s0);
+ const __m128i err0 = _mm_madd_epi16(e0, e0);
+ sum32 = _mm_add_epi32(sum32, err0);
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ flt1 += flt1_stride;
+ const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32);
+ const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8));
+ sum64 = _mm_add_epi64(sum64, sum64_0);
+ sum64 = _mm_add_epi64(sum64, sum64_1);
+ }
+ } else if (params->r[0] > 0) {
+ __m128i xq_coeff = pair_set_epi16(xq[0], -(xq[0] << SGRPROJ_RST_BITS));
+ for (i = 0; i < height; ++i) {
+ __m128i sum32 = _mm_setzero_si128();
+ for (j = 0; j < width - 8; j += 8) {
+ const __m128i d0 = _mm_cvtepu8_epi16(xx_loadl_64(dat + j));
+ const __m128i s0 = _mm_cvtepu8_epi16(xx_loadl_64(src + j));
+ const __m128i flt0_16b =
+ _mm_packs_epi32(xx_loadu_128(flt0 + j), xx_loadu_128(flt0 + j + 4));
+ const __m128i v0 =
+ _mm_madd_epi16(xq_coeff, _mm_unpacklo_epi16(flt0_16b, d0));
+ const __m128i v1 =
+ _mm_madd_epi16(xq_coeff, _mm_unpackhi_epi16(flt0_16b, d0));
+ const __m128i vr0 = _mm_srai_epi32(_mm_add_epi32(v0, rounding), shift);
+ const __m128i vr1 = _mm_srai_epi32(_mm_add_epi32(v1, rounding), shift);
+ const __m128i e0 =
+ _mm_sub_epi16(_mm_add_epi16(_mm_packs_epi32(vr0, vr1), d0), s0);
+ const __m128i err0 = _mm_madd_epi16(e0, e0);
+ sum32 = _mm_add_epi32(sum32, err0);
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq[0] * (flt0[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32);
+ const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8));
+ sum64 = _mm_add_epi64(sum64, sum64_0);
+ sum64 = _mm_add_epi64(sum64, sum64_1);
+ }
+ } else if (params->r[1] > 0) {
+ __m128i xq_coeff = pair_set_epi16(xq[1], -(xq[1] << SGRPROJ_RST_BITS));
+ for (i = 0; i < height; ++i) {
+ __m128i sum32 = _mm_setzero_si128();
+ for (j = 0; j < width - 8; j += 8) {
+ const __m128i d0 = _mm_cvtepu8_epi16(xx_loadl_64(dat + j));
+ const __m128i s0 = _mm_cvtepu8_epi16(xx_loadl_64(src + j));
+ const __m128i flt1_16b =
+ _mm_packs_epi32(xx_loadu_128(flt1 + j), xx_loadu_128(flt1 + j + 4));
+ const __m128i v0 =
+ _mm_madd_epi16(xq_coeff, _mm_unpacklo_epi16(flt1_16b, d0));
+ const __m128i v1 =
+ _mm_madd_epi16(xq_coeff, _mm_unpackhi_epi16(flt1_16b, d0));
+ const __m128i vr0 = _mm_srai_epi32(_mm_add_epi32(v0, rounding), shift);
+ const __m128i vr1 = _mm_srai_epi32(_mm_add_epi32(v1, rounding), shift);
+ const __m128i e0 =
+ _mm_sub_epi16(_mm_add_epi16(_mm_packs_epi32(vr0, vr1), d0), s0);
+ const __m128i err0 = _mm_madd_epi16(e0, e0);
+ sum32 = _mm_add_epi32(sum32, err0);
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq[1] * (flt1[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt1 += flt1_stride;
+ const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32);
+ const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8));
+ sum64 = _mm_add_epi64(sum64, sum64_0);
+ sum64 = _mm_add_epi64(sum64, sum64_1);
+ }
+ } else {
+ __m128i sum32 = _mm_setzero_si128();
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j < width - 16; j += 16) {
+ const __m128i d = xx_loadu_128(dat + j);
+ const __m128i s = xx_loadu_128(src + j);
+ const __m128i d0 = _mm_cvtepu8_epi16(d);
+ const __m128i d1 = _mm_cvtepu8_epi16(_mm_srli_si128(d, 8));
+ const __m128i s0 = _mm_cvtepu8_epi16(s);
+ const __m128i s1 = _mm_cvtepu8_epi16(_mm_srli_si128(s, 8));
+ const __m128i diff0 = _mm_sub_epi16(d0, s0);
+ const __m128i diff1 = _mm_sub_epi16(d1, s1);
+ const __m128i err0 = _mm_madd_epi16(diff0, diff0);
+ const __m128i err1 = _mm_madd_epi16(diff1, diff1);
+ sum32 = _mm_add_epi32(sum32, err0);
+ sum32 = _mm_add_epi32(sum32, err1);
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t e = (int32_t)(dat[k]) - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ }
+ const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32);
+ const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8));
+ sum64 = _mm_add_epi64(sum64_0, sum64_1);
+ }
+ int64_t sum[2];
+ xx_storeu_128(sum, sum64);
+ err += sum[0] + sum[1];
+ return err;
+}
diff --git a/test/pickrst_test.cc b/test/pickrst_test.cc
new file mode 100644
index 0000000..040e8e8
--- /dev/null
+++ b/test/pickrst_test.cc
@@ -0,0 +1,187 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+
+#include "test/function_equivalence_test.h"
+#include "test/register_state_check.h"
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+#include "aom/aom_integer.h"
+#include "av1/encoder/pickrst.h"
+using libaom_test::FunctionEquivalenceTest;
+
+#define MAX_DATA_BLOCK 384
+
+namespace {
+static const int kIterations = 100;
+
+typedef int64_t (*lowbd_pixel_proj_error_func)(
+ const uint8_t *src8, int width, int height, int src_stride,
+ const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
+ int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params);
+
+typedef libaom_test::FuncParam<lowbd_pixel_proj_error_func> TestFuncs;
+
+////////////////////////////////////////////////////////////////////////////////
+// 8 bit
+////////////////////////////////////////////////////////////////////////////////
+
+typedef ::testing::tuple<const lowbd_pixel_proj_error_func>
+ PixelProjErrorTestParam;
+
+class PixelProjErrorTest
+ : public ::testing::TestWithParam<PixelProjErrorTestParam> {
+ public:
+ virtual void SetUp() {
+ target_func_ = GET_PARAM(0);
+ src_ = (uint8_t *)(aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK *
+ sizeof(uint8_t)));
+ dgd_ = (uint8_t *)(aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK *
+ sizeof(uint8_t)));
+ flt0_ = (int32_t *)(aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK *
+ sizeof(int32_t)));
+ flt1_ = (int32_t *)(aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK *
+ sizeof(int32_t)));
+ }
+ virtual void TearDown() {
+ aom_free(src_);
+ aom_free(dgd_);
+ aom_free(flt0_);
+ aom_free(flt1_);
+ }
+ void runPixelProjErrorTest(int32_t run_times);
+ void runPixelProjErrorTest_ExtremeValues();
+
+ private:
+ lowbd_pixel_proj_error_func target_func_;
+ ACMRandom rng_;
+ uint8_t *src_;
+ uint8_t *dgd_;
+ int32_t *flt0_;
+ int32_t *flt1_;
+};
+
+void PixelProjErrorTest::runPixelProjErrorTest(int32_t run_times) {
+ int h_end = run_times != 1 ? 128 : (rng_.Rand16() % MAX_DATA_BLOCK) + 1;
+ int v_end = run_times != 1 ? 128 : (rng_.Rand16() % MAX_DATA_BLOCK) + 1;
+ const int dgd_stride = MAX_DATA_BLOCK;
+ const int src_stride = MAX_DATA_BLOCK;
+ const int flt0_stride = MAX_DATA_BLOCK;
+ const int flt1_stride = MAX_DATA_BLOCK;
+ sgr_params_type params;
+ int xq[2];
+ const int iters = run_times == 1 ? kIterations : 4;
+ for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) {
+ int64_t err_ref = 0, err_test = 1;
+ for (int i = 0; i < MAX_DATA_BLOCK * MAX_DATA_BLOCK; ++i) {
+ dgd_[i] = rng_.Rand8();
+ src_[i] = rng_.Rand8();
+ flt0_[i] = rng_.Rand15Signed();
+ flt1_[i] = rng_.Rand15Signed();
+ }
+ xq[0] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+ xq[1] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+ params.r[0] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter % 2);
+ params.r[1] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter / 2);
+ params.s[0] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter % 2);
+ params.s[1] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter / 2);
+ uint8_t *dgd = dgd_;
+ uint8_t *src = src_;
+
+ aom_usec_timer timer;
+ aom_usec_timer_start(&timer);
+ for (int i = 0; i < run_times; ++i) {
+ err_ref = av1_lowbd_pixel_proj_error_c(src, h_end, v_end, src_stride, dgd,
+ dgd_stride, flt0_, flt0_stride,
+ flt1_, flt1_stride, xq, ¶ms);
+ }
+ aom_usec_timer_mark(&timer);
+ const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));
+ aom_usec_timer_start(&timer);
+ for (int i = 0; i < run_times; ++i) {
+ err_test =
+ target_func_(src, h_end, v_end, src_stride, dgd, dgd_stride, flt0_,
+ flt0_stride, flt1_, flt1_stride, xq, ¶ms);
+ }
+ aom_usec_timer_mark(&timer);
+ const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));
+ if (run_times > 10) {
+ printf("r0 %d r1 %d %3dx%-3d:%7.2f/%7.2fns (%3.2f)\n", params.r[0],
+ params.r[1], h_end, v_end, time1, time2, time1 / time2);
+ }
+ ASSERT_EQ(err_ref, err_test);
+ }
+}
+
+void PixelProjErrorTest::runPixelProjErrorTest_ExtremeValues() {
+ const int h_start = 0;
+ int h_end = 192;
+ const int v_start = 0;
+ int v_end = 192;
+ const int dgd_stride = MAX_DATA_BLOCK;
+ const int src_stride = MAX_DATA_BLOCK;
+ const int flt0_stride = MAX_DATA_BLOCK;
+ const int flt1_stride = MAX_DATA_BLOCK;
+ sgr_params_type params;
+ int xq[2];
+ const int iters = kIterations;
+ for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) {
+ int64_t err_ref = 0, err_test = 1;
+ for (int i = 0; i < MAX_DATA_BLOCK * MAX_DATA_BLOCK; ++i) {
+ dgd_[i] = 0;
+ src_[i] = 255;
+ flt0_[i] = rng_.Rand15Signed();
+ flt1_[i] = rng_.Rand15Signed();
+ }
+ xq[0] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+ xq[1] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+ params.r[0] = rng_.Rand8() % MAX_RADIUS;
+ params.r[1] = rng_.Rand8() % MAX_RADIUS;
+ params.s[0] = rng_.Rand8() % MAX_RADIUS;
+ params.s[1] = rng_.Rand8() % MAX_RADIUS;
+ uint8_t *dgd = dgd_;
+ uint8_t *src = src_;
+
+ err_ref = av1_lowbd_pixel_proj_error_c(
+ src, h_end - h_start, v_end - v_start, src_stride, dgd, dgd_stride,
+ flt0_, flt0_stride, flt1_, flt1_stride, xq, ¶ms);
+
+ err_test = target_func_(src, h_end - h_start, v_end - v_start, src_stride,
+ dgd, dgd_stride, flt0_, flt0_stride, flt1_,
+ flt1_stride, xq, ¶ms);
+
+ ASSERT_EQ(err_ref, err_test);
+ }
+}
+
+TEST_P(PixelProjErrorTest, RandomValues) { runPixelProjErrorTest(1); }
+
+TEST_P(PixelProjErrorTest, ExtremeValues) {
+ runPixelProjErrorTest_ExtremeValues();
+}
+
+TEST_P(PixelProjErrorTest, DISABLED_Speed) { runPixelProjErrorTest(200000); }
+
+#if HAVE_SSE4_1
+INSTANTIATE_TEST_CASE_P(SSE4_1, PixelProjErrorTest,
+ ::testing::Values(av1_lowbd_pixel_proj_error_sse4_1));
+#endif // HAVE_SSE4_1
+
+#if HAVE_AVX2
+
+INSTANTIATE_TEST_CASE_P(AVX2, PixelProjErrorTest,
+ ::testing::Values(av1_lowbd_pixel_proj_error_avx2));
+#endif // HAVE_AVX2
+
+} // namespace
diff --git a/test/test.cmake b/test/test.cmake
index a6b7eaa..1aa08d3 100644
--- a/test/test.cmake
+++ b/test/test.cmake
@@ -72,6 +72,7 @@
"${AOM_ROOT}/test/resize_test.cc"
"${AOM_ROOT}/test/scalability_test.cc"
"${AOM_ROOT}/test/wiener_test.cc"
+ "${AOM_ROOT}/test/pickrst_test.cc"
"${AOM_ROOT}/test/y4m_test.cc"
"${AOM_ROOT}/test/y4m_video_source.h"
"${AOM_ROOT}/test/yuv_video_source.h")