Add SSE4.1 and AVX2 highbd_pixel_proj_error
Add SSE4.1 and AVX2 implementations of the high bit-depth
pixel_proj_error function used in the self-guided filter encoder.
SSE4.1 unit test speed-ups:
r0=0 r1=0: 5.1x
r0=0 r1=1: 2.7x
r0=1 r1=0: 2.7x
r0=1 r1=1: 2.4x
AVX2 unit test speed-ups:
r0=0 r1=0: 6.0x
r0=0 r1=1: 3.5x
r0=1 r1=0: 3.5x
r0=1 r1=1: 2.8x
Encoding 15 frames of ducks_take_off_1080p using AVX2:
cpu=5 bd=10 speed-up=3.6%
cpu=5 bd=12 speed-up=4.0%
cpu=3 bd=10 speed-up=2.4%
cpu=3 bd=12 speed-up=3.2%
cpu=2 bd=10 speed-up=1.0%
cpu=2 bd=12 speed-up=2.1%
Change-Id: I3aac63eeb8d1af33c1d6414dd910708d3dad1793
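
For reference, all three code paths below evaluate the same per-pixel
quantity; the following scalar sketch (proj_err_px is a hypothetical helper
name, not part of the patch; the SGRPROJ_* constants come from
av1/common/restoration.h) condenses the C reference added in pickrst.c.
With only one filter active the corresponding xq term drops out, and with
neither active the error reduces to (d - s)^2:

    static inline int64_t proj_err_px(int32_t s, int32_t d, int32_t f0,
                                      int32_t f1, int xq0, int xq1) {
      const int shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS;
      const int32_t half = 1 << (shift - 1);
      const int32_t u = d << SGRPROJ_RST_BITS;  // corrupt pixel, filter scale
      const int32_t v = xq0 * (f0 - u) + xq1 * (f1 - u) + half;
      const int32_t e = (v >> shift) + d - s;   // rounded projection residual
      return (int64_t)e * e;
    }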
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 1cd6b4a..481d6b8 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -298,6 +298,9 @@
add_proto qw/int64_t av1_lowbd_pixel_proj_error/, " const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params";
specialize qw/av1_lowbd_pixel_proj_error sse4_1 avx2/;
+
+ add_proto qw/int64_t av1_highbd_pixel_proj_error/, " const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params";
+ specialize qw/av1_highbd_pixel_proj_error sse4_1 avx2/;
}
# end encoder functions
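
For context, a sketch of what the specialization above amounts to at
runtime; this illustrates the usual rtcd dispatch pattern rather than the
literal generated header:

    /* The generated header declares a function pointer that rtcd setup
       resolves once from the detected CPU flags (illustration only): */
    int64_t (*av1_highbd_pixel_proj_error)(
        const uint8_t *src8, int width, int height, int src_stride,
        const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
        int32_t *flt1, int flt1_stride, int xq[2],
        const sgr_params_type *params);
    /* setup:  ptr = av1_highbd_pixel_proj_error_c;
               if (flags & HAS_SSE4_1) ptr = av1_highbd_pixel_proj_error_sse4_1;
               if (flags & HAS_AVX2)   ptr = av1_highbd_pixel_proj_error_avx2; */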
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 4faa849..2858d16 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -255,88 +255,97 @@
return err;
}
+int64_t av1_highbd_pixel_proj_error_c(const uint8_t *src8, int width,
+ int height, int src_stride,
+ const uint8_t *dat8, int dat_stride,
+ int32_t *flt0, int flt0_stride,
+ int32_t *flt1, int flt1_stride, int xq[2],
+ const sgr_params_type *params) {
+ const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+ const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
+ int i, j;
+ int64_t err = 0;
+ const int32_t half = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
+ if (params->r[0] > 0 && params->r[1] > 0) {
+ int xq0 = xq[0];
+ int xq1 = xq[1];
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j < width; ++j) {
+ const int32_t d = dat[j];
+ const int32_t s = src[j];
+ const int32_t u = (int32_t)(d << SGRPROJ_RST_BITS);
+ int32_t v0 = flt0[j] - u;
+ int32_t v1 = flt1[j] - u;
+ int32_t v = half;
+ v += xq0 * v0;
+ v += xq1 * v1;
+ const int32_t e = (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + d - s;
+ err += e * e;
+ }
+ dat += dat_stride;
+ flt0 += flt0_stride;
+ flt1 += flt1_stride;
+ src += src_stride;
+ }
+ } else if (params->r[0] > 0 || params->r[1] > 0) {
+ int exq;
+ int32_t *flt;
+ int flt_stride;
+ if (params->r[0] > 0) {
+ exq = xq[0];
+ flt = flt0;
+ flt_stride = flt0_stride;
+ } else {
+ exq = xq[1];
+ flt = flt1;
+ flt_stride = flt1_stride;
+ }
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j < width; ++j) {
+ const int32_t d = dat[j];
+ const int32_t s = src[j];
+ const int32_t u = (int32_t)(d << SGRPROJ_RST_BITS);
+ int32_t v = half;
+ v += exq * (flt[j] - u);
+ const int32_t e = (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + d - s;
+ err += e * e;
+ }
+ dat += dat_stride;
+ flt += flt_stride;
+ src += src_stride;
+ }
+ } else {
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j < width; ++j) {
+ const int32_t d = dat[j];
+ const int32_t s = src[j];
+ const int32_t e = d - s;
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ }
+ }
+ return err;
+}
+
static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
int src_stride, const uint8_t *dat8,
int dat_stride, int use_highbitdepth,
int32_t *flt0, int flt0_stride,
int32_t *flt1, int flt1_stride, int *xqd,
const sgr_params_type *params) {
- int i, j;
- int64_t err = 0;
int xq[2];
decode_xq(xqd, xq, params);
if (!use_highbitdepth) {
- err = av1_lowbd_pixel_proj_error(src8, width, height, src_stride, dat8,
- dat_stride, flt0, flt0_stride, flt1,
- flt1_stride, xq, params);
+ return av1_lowbd_pixel_proj_error(src8, width, height, src_stride, dat8,
+ dat_stride, flt0, flt0_stride, flt1,
+ flt1_stride, xq, params);
} else {
- const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
- const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
- const int32_t half = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
- if (params->r[0] > 0 && params->r[1] > 0) {
- int xq0 = xq[0];
- int xq1 = xq[1];
- for (i = 0; i < height; ++i) {
- for (j = 0; j < width; ++j) {
- const int32_t d = dat[j];
- const int32_t s = src[j];
- const int32_t u = (int32_t)(d << SGRPROJ_RST_BITS);
- int32_t v0 = flt0[j] - u;
- int32_t v1 = flt1[j] - u;
- int32_t v = half;
- v += xq0 * v0;
- v += xq1 * v1;
- const int32_t e =
- (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + d - s;
- err += e * e;
- }
- dat += dat_stride;
- flt0 += flt0_stride;
- flt1 += flt1_stride;
- src += src_stride;
- }
- } else if (params->r[0] > 0 || params->r[1] > 0) {
- int exq;
- int32_t *flt;
- int flt_stride;
- if (params->r[0] > 0) {
- exq = xq[0];
- flt = flt0;
- flt_stride = flt0_stride;
- } else {
- exq = xq[1];
- flt = flt1;
- flt_stride = flt1_stride;
- }
- for (i = 0; i < height; ++i) {
- for (j = 0; j < width; ++j) {
- const int32_t d = dat[j];
- const int32_t s = src[j];
- const int32_t u = (int32_t)(d << SGRPROJ_RST_BITS);
- int32_t v = half;
- v += exq * (flt[j] - u);
- const int32_t e =
- (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + d - s;
- err += e * e;
- }
- dat += dat_stride;
- flt += flt_stride;
- src += src_stride;
- }
- } else {
- for (i = 0; i < height; ++i) {
- for (j = 0; j < width; ++j) {
- const int32_t d = dat[j];
- const int32_t s = src[j];
- const int32_t e = d - s;
- err += e * e;
- }
- dat += dat_stride;
- src += src_stride;
- }
- }
+ return av1_highbd_pixel_proj_error(src8, width, height, src_stride, dat8,
+ dat_stride, flt0, flt0_stride, flt1,
+ flt1_stride, xq, params);
}
- return err;
}
#define USE_SGRPROJ_REFINEMENT_SEARCH 1
diff --git a/av1/encoder/x86/pickrst_avx2.c b/av1/encoder/x86/pickrst_avx2.c
index 579e424..7a63c60 100644
--- a/av1/encoder/x86/pickrst_avx2.c
+++ b/av1/encoder/x86/pickrst_avx2.c
@@ -621,3 +621,223 @@
err += sum[0] + sum[1] + sum[2] + sum[3];
return err;
}
+
+int64_t av1_highbd_pixel_proj_error_avx2(
+ const uint8_t *src8, int width, int height, int src_stride,
+ const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
+ int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
+ int i, j, k;
+ const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS;
+ const __m256i rounding = _mm256_set1_epi32(1 << (shift - 1));
+ __m256i sum64 = _mm256_setzero_si256();
+ const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+ const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
+ int64_t err = 0;
+ if (params->r[0] > 0 && params->r[1] > 0) { // Both filters are enabled
+ const __m256i xq0 = _mm256_set1_epi32(xq[0]);
+ const __m256i xq1 = _mm256_set1_epi32(xq[1]);
+ for (i = 0; i < height; ++i) {
+ __m256i sum32 = _mm256_setzero_si256();
+ for (j = 0; j <= width - 16; j += 16) { // Process 16 pixels at a time
+ // Load 16 pixels each from source image and corrupted image
+ const __m256i s0 = yy_loadu_256(src + j);
+ const __m256i d0 = yy_loadu_256(dat + j);
+ // s0 = [15 14 13 12 11 10 9 8] [7 6 5 4 3 2 1 0] as u16 (indices)
+
+ // Shift-up each pixel to match filtered image scaling
+ const __m256i u0 = _mm256_slli_epi16(d0, SGRPROJ_RST_BITS);
+
+ // Split u0 into two halves and pad each from u16 to i32
+ const __m256i u0l = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(u0));
+ const __m256i u0h =
+ _mm256_cvtepu16_epi32(_mm256_extracti128_si256(u0, 1));
+ // u0h, u0l = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0] as u32
+
+ // Load 16 pixels from each filtered image
+ const __m256i flt0l = yy_loadu_256(flt0 + j);
+ const __m256i flt0h = yy_loadu_256(flt0 + j + 8);
+ const __m256i flt1l = yy_loadu_256(flt1 + j);
+ const __m256i flt1h = yy_loadu_256(flt1 + j + 8);
+ // flt?h, flt?l = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0] as i32
+
+ // Subtract shifted corrupt image from each filtered image
+ const __m256i flt0l_subu = _mm256_sub_epi32(flt0l, u0l);
+ const __m256i flt0h_subu = _mm256_sub_epi32(flt0h, u0h);
+ const __m256i flt1l_subu = _mm256_sub_epi32(flt1l, u0l);
+ const __m256i flt1h_subu = _mm256_sub_epi32(flt1h, u0h);
+
+ // Multiply basis vectors by appropriate coefficients
+ const __m256i v0l = _mm256_mullo_epi32(flt0l_subu, xq0);
+ const __m256i v0h = _mm256_mullo_epi32(flt0h_subu, xq0);
+ const __m256i v1l = _mm256_mullo_epi32(flt1l_subu, xq1);
+ const __m256i v1h = _mm256_mullo_epi32(flt1h_subu, xq1);
+
+ // Add together the contributions from the two basis vectors
+ const __m256i vl = _mm256_add_epi32(v0l, v1l);
+ const __m256i vh = _mm256_add_epi32(v0h, v1h);
+
+ // Right-shift v with appropriate rounding
+ const __m256i vrl =
+ _mm256_srai_epi32(_mm256_add_epi32(vl, rounding), shift);
+ const __m256i vrh =
+ _mm256_srai_epi32(_mm256_add_epi32(vh, rounding), shift);
+ // vrh, vrl = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0]
+
+ // Saturate each i32 to an i16 then combine both halves
+ // The permute (control=[3 1 2 0]) restores in-order indices across lanes
+ const __m256i vr =
+ _mm256_permute4x64_epi64(_mm256_packs_epi32(vrl, vrh), 0xd8);
+ // intermediate = [15 14 13 12 7 6 5 4] [11 10 9 8 3 2 1 0]
+ // vr = [15 14 13 12 11 10 9 8] [7 6 5 4 3 2 1 0]
+
+ // Add twin-subspace-sgr-filter to corrupt image then subtract source
+ const __m256i e0 = _mm256_sub_epi16(_mm256_add_epi16(vr, d0), s0);
+
+ // Calculate squared error and add adjacent values
+ const __m256i err0 = _mm256_madd_epi16(e0, e0);
+
+ sum32 = _mm256_add_epi32(sum32, err0);
+ }
+
+ const __m256i sum32l =
+ _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sum32));
+ sum64 = _mm256_add_epi64(sum64, sum32l);
+ const __m256i sum32h =
+ _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, sum32h);
+
+ // Process remaining pixels in this row (modulo 16)
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ flt1 += flt1_stride;
+ }
+ } else if (params->r[0] > 0 || params->r[1] > 0) { // Only one filter enabled
+ const int32_t xq_on = (params->r[0] > 0) ? xq[0] : xq[1];
+ const __m256i xq_active = _mm256_set1_epi32(xq_on);
+ const __m256i xq_inactive =
+ _mm256_set1_epi32(-xq_on * (1 << SGRPROJ_RST_BITS));
+ const int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
+ const int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
+ for (i = 0; i < height; ++i) {
+ __m256i sum32 = _mm256_setzero_si256();
+ for (j = 0; j <= width - 16; j += 16) {
+ // Load 16 pixels from source image
+ const __m256i s0 = yy_loadu_256(src + j);
+ // s0 = [15 14 13 12 11 10 9 8] [7 6 5 4 3 2 1 0] as u16
+
+ // Load 16 pixels from corrupted image and pad each u16 to i32
+ const __m256i d0 = yy_loadu_256(dat + j);
+ const __m256i d0h =
+ _mm256_cvtepu16_epi32(_mm256_extracti128_si256(d0, 1));
+ const __m256i d0l = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(d0));
+ // d0 = [15 14 13 12 11 10 9 8] [7 6 5 4 3 2 1 0] as u16
+ // d0h, d0l = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0] as i32
+
+ // Load 16 pixels from the filtered image
+ const __m256i flth = yy_loadu_256(flt + j + 8);
+ const __m256i fltl = yy_loadu_256(flt + j);
+ // flth, fltl = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0] as i32
+
+ const __m256i flth_xq = _mm256_mullo_epi32(flth, xq_active);
+ const __m256i fltl_xq = _mm256_mullo_epi32(fltl, xq_active);
+ const __m256i d0h_xq = _mm256_mullo_epi32(d0h, xq_inactive);
+ const __m256i d0l_xq = _mm256_mullo_epi32(d0l, xq_inactive);
+
+ const __m256i vh = _mm256_add_epi32(flth_xq, d0h_xq);
+ const __m256i vl = _mm256_add_epi32(fltl_xq, d0l_xq);
+
+ // Shift this down with appropriate rounding
+ const __m256i vrh =
+ _mm256_srai_epi32(_mm256_add_epi32(vh, rounding), shift);
+ const __m256i vrl =
+ _mm256_srai_epi32(_mm256_add_epi32(vl, rounding), shift);
+ // vrh, vrl = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0] as i32
+
+ // Saturate each i32 to an i16 then combine both halves
+ // The permute (control=[3 1 2 0]) restores in-order indices across lanes
+ const __m256i vr =
+ _mm256_permute4x64_epi64(_mm256_packs_epi32(vrl, vrh), 0xd8);
+ // intermediate = [15 14 13 12 7 6 5 4] [11 10 9 8 3 2 1 0] as i16
+ // vr = [15 14 13 12 11 10 9 8] [7 6 5 4 3 2 1 0] as i16
+
+ // Add twin-subspace-sgr-filter to corrupt image then subtract source
+ const __m256i e0 = _mm256_sub_epi16(_mm256_add_epi16(vr, d0), s0);
+
+ // Calculate squared error and add adjacent values
+ const __m256i err0 = _mm256_madd_epi16(e0, e0);
+
+ sum32 = _mm256_add_epi32(sum32, err0);
+ }
+
+ const __m256i sum32l =
+ _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sum32));
+ sum64 = _mm256_add_epi64(sum64, sum32l);
+ const __m256i sum32h =
+ _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, sum32h);
+
+ // Process remaining pixels in this row (modulo 16)
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq_on * (flt[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt += flt_stride;
+ }
+ } else { // Neither filter is enabled
+ for (i = 0; i < height; ++i) {
+ __m256i sum32 = _mm256_setzero_si256();
+ for (j = 0; j <= width - 32; j += 32) {
+ // Load 2x16 u16 from source image
+ const __m256i s0l = yy_loadu_256(src + j);
+ const __m256i s0h = yy_loadu_256(src + j + 16);
+
+ // Load 2x16 u16 from corrupted image
+ const __m256i d0l = yy_loadu_256(dat + j);
+ const __m256i d0h = yy_loadu_256(dat + j + 16);
+
+ // Subtract source image from corrupted image
+ const __m256i diffl = _mm256_sub_epi16(d0l, s0l);
+ const __m256i diffh = _mm256_sub_epi16(d0h, s0h);
+
+ // Square error and add adjacent values
+ const __m256i err0l = _mm256_madd_epi16(diffl, diffl);
+ const __m256i err0h = _mm256_madd_epi16(diffh, diffh);
+
+ sum32 = _mm256_add_epi32(sum32, err0l);
+ sum32 = _mm256_add_epi32(sum32, err0h);
+ }
+
+ const __m256i sum32l =
+ _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sum32));
+ sum64 = _mm256_add_epi64(sum64, sum32l);
+ const __m256i sum32h =
+ _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, sum32h);
+
+ // Process remaining pixels in this row (modulo 32)
+ for (k = j; k < width; ++k) {
+ const int32_t e = (int32_t)(dat[k]) - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ }
+ }
+
+ // Sum the four i64 values in sum64 into err
+ int64_t sum[4];
+ yy_storeu_256(sum, sum64);
+ err += sum[0] + sum[1] + sum[2] + sum[3];
+ return err;
+}
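
A note on the reduction pattern above, which the SSE4.1 version mirrors on
128-bit registers: squared errors are pair-summed into eight i32 lanes by
_mm256_madd_epi16, accumulated over one row only (sum32 is reset each row,
keeping the 32-bit accumulators in range), then zero-extended and folded
into four i64 lanes. A scalar model of that fold (fold_row_sums is a
hypothetical name):

    /* Scalar equivalent of the per-row fold: eight non-negative i32 lane
       sums from the madd are zero-extended to i64 and added to sum64. */
    static inline void fold_row_sums(const int32_t lane[8], int64_t acc[4]) {
      for (int n = 0; n < 4; ++n) acc[n] += (uint32_t)lane[n];      // low 128
      for (int n = 0; n < 4; ++n) acc[n] += (uint32_t)lane[n + 4];  // high 128
    }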
diff --git a/av1/encoder/x86/pickrst_sse4.c b/av1/encoder/x86/pickrst_sse4.c
index a067ab6..2326736 100644
--- a/av1/encoder/x86/pickrst_sse4.c
+++ b/av1/encoder/x86/pickrst_sse4.c
@@ -621,3 +621,209 @@
err += sum[0] + sum[1];
return err;
}
+
+int64_t av1_highbd_pixel_proj_error_sse4_1(
+ const uint8_t *src8, int width, int height, int src_stride,
+ const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
+ int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
+ int i, j, k;
+ const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS;
+ const __m128i rounding = _mm_set1_epi32(1 << (shift - 1));
+ __m128i sum64 = _mm_setzero_si128();
+ const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+ const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
+ int64_t err = 0;
+ if (params->r[0] > 0 && params->r[1] > 0) { // Both filters are enabled
+ const __m128i xq0 = _mm_set1_epi32(xq[0]);
+ const __m128i xq1 = _mm_set1_epi32(xq[1]);
+
+ for (i = 0; i < height; ++i) {
+ __m128i sum32 = _mm_setzero_si128();
+ for (j = 0; j <= width - 8; j += 8) {
+ // Load 8x pixels from source image
+ const __m128i s0 = xx_loadu_128(src + j);
+ // s0 = [7 6 5 4 3 2 1 0] as u16 (indices of src[])
+
+ // Load 8x pixels from corrupted image
+ const __m128i d0 = xx_loadu_128(dat + j);
+ // d0 = [7 6 5 4 3 2 1 0] as u16 (indices of dat[])
+
+ // Shift each pixel value up by SGRPROJ_RST_BITS
+ const __m128i u0 = _mm_slli_epi16(d0, SGRPROJ_RST_BITS);
+
+ // Split u0 into two halves and pad each from u16 to i32
+ const __m128i u0l = _mm_cvtepu16_epi32(u0);
+ const __m128i u0h = _mm_cvtepu16_epi32(_mm_srli_si128(u0, 8));
+ // u0h = [7 6 5 4] as i32, u0l = [3 2 1 0] as i32, all dat[] indices
+
+ // Load 8 pixels from first and second filtered images
+ const __m128i flt0l = xx_loadu_128(flt0 + j);
+ const __m128i flt0h = xx_loadu_128(flt0 + j + 4);
+ const __m128i flt1l = xx_loadu_128(flt1 + j);
+ const __m128i flt1h = xx_loadu_128(flt1 + j + 4);
+ // flt0h, flt0l = [7 6 5 4], [3 2 1 0] as i32 (indices of flt0+j)
+ // flt1h, flt1l = [7 6 5 4], [3 2 1 0] as i32 (indices of flt1+j)
+
+ // Subtract shifted corrupt image from each filtered image
+ // This gives our two basis vectors for the projection
+ const __m128i flt0l_subu = _mm_sub_epi32(flt0l, u0l);
+ const __m128i flt0h_subu = _mm_sub_epi32(flt0h, u0h);
+ const __m128i flt1l_subu = _mm_sub_epi32(flt1l, u0l);
+ const __m128i flt1h_subu = _mm_sub_epi32(flt1h, u0h);
+ // flt?h_subu = [ f[7]-u[7] f[6]-u[6] f[5]-u[5] f[4]-u[4] ] as i32
+ // flt?l_subu = [ f[3]-u[3] f[2]-u[2] f[1]-u[1] f[0]-u[0] ] as i32
+
+ // Multiply each basis vector by the corresponding coefficient
+ const __m128i v0l = _mm_mullo_epi32(flt0l_subu, xq0);
+ const __m128i v0h = _mm_mullo_epi32(flt0h_subu, xq0);
+ const __m128i v1l = _mm_mullo_epi32(flt1l_subu, xq1);
+ const __m128i v1h = _mm_mullo_epi32(flt1h_subu, xq1);
+
+ // Add together the contribution from each scaled basis vector
+ const __m128i vl = _mm_add_epi32(v0l, v1l);
+ const __m128i vh = _mm_add_epi32(v0h, v1h);
+
+ // Right-shift v with appropriate rounding
+ const __m128i vrl = _mm_srai_epi32(_mm_add_epi32(vl, rounding), shift);
+ const __m128i vrh = _mm_srai_epi32(_mm_add_epi32(vh, rounding), shift);
+
+ // Saturate each i32 value to i16 and combine lower and upper halves
+ const __m128i vr = _mm_packs_epi32(vrl, vrh);
+
+ // Add twin-subspace-sgr-filter to corrupt image then subtract source
+ const __m128i e0 = _mm_sub_epi16(_mm_add_epi16(vr, d0), s0);
+
+ // Calculate squared error and add adjacent values
+ const __m128i err0 = _mm_madd_epi16(e0, e0);
+
+ sum32 = _mm_add_epi32(sum32, err0);
+ }
+
+ const __m128i sum32l = _mm_cvtepu32_epi64(sum32);
+ sum64 = _mm_add_epi64(sum64, sum32l);
+ const __m128i sum32h = _mm_cvtepu32_epi64(_mm_srli_si128(sum32, 8));
+ sum64 = _mm_add_epi64(sum64, sum32h);
+
+ // Process remaining pixels in this row (modulo 8)
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ flt1 += flt1_stride;
+ }
+ } else if (params->r[0] > 0 || params->r[1] > 0) { // Only one filter enabled
+ const int32_t xq_on = (params->r[0] > 0) ? xq[0] : xq[1];
+ const __m128i xq_active = _mm_set1_epi32(xq_on);
+ const __m128i xq_inactive =
+ _mm_set1_epi32(-xq_on * (1 << SGRPROJ_RST_BITS));
+ const int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
+ const int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
+ for (i = 0; i < height; ++i) {
+ __m128i sum32 = _mm_setzero_si128();
+ for (j = 0; j <= width - 8; j += 8) {
+ // Load 8x pixels from source image
+ const __m128i s0 = xx_loadu_128(src + j);
+ // s0 = [7 6 5 4 3 2 1 0] as u16 (indices of src[])
+
+ // Load 8x pixels from corrupted image and pad each u16 to i32
+ const __m128i d0 = xx_loadu_128(dat + j);
+ const __m128i d0h = _mm_cvtepu16_epi32(_mm_srli_si128(d0, 8));
+ const __m128i d0l = _mm_cvtepu16_epi32(d0);
+ // d0h, d0l = [7 6 5 4], [3 2 1 0] as u32 (indices of dat[])
+
+ // Load 8 pixels from the filtered image
+ const __m128i flth = xx_loadu_128(flt + j + 4);
+ const __m128i fltl = xx_loadu_128(flt + j);
+ // flth, fltl = [7 6 5 4], [3 2 1 0] as i32 (indices of flt+j)
+
+ const __m128i flth_xq = _mm_mullo_epi32(flth, xq_active);
+ const __m128i fltl_xq = _mm_mullo_epi32(fltl, xq_active);
+ const __m128i d0h_xq = _mm_mullo_epi32(d0h, xq_inactive);
+ const __m128i d0l_xq = _mm_mullo_epi32(d0l, xq_inactive);
+
+ const __m128i vh = _mm_add_epi32(flth_xq, d0h_xq);
+ const __m128i vl = _mm_add_epi32(fltl_xq, d0l_xq);
+ // vh = [ xq*(f[7]-u[7]) xq*(f[6]-u[6]) xq*(f[5]-u[5]) xq*(f[4]-u[4]) ]
+ // vl = [ xq*(f[3]-u[3]) xq*(f[2]-u[2]) xq*(f[1]-u[1]) xq*(f[0]-u[0]) ]
+ // where u[k] = d[k] << SGRPROJ_RST_BITS
+
+ // Shift this down with appropriate rounding
+ const __m128i vrh = _mm_srai_epi32(_mm_add_epi32(vh, rounding), shift);
+ const __m128i vrl = _mm_srai_epi32(_mm_add_epi32(vl, rounding), shift);
+
+ // Saturate vrl and vrh from i32 to i16 then pack together
+ const __m128i vr = _mm_packs_epi32(vrl, vrh);
+
+ // Add twin-subspace-sgr-filter to corrupt image then subtract source
+ const __m128i e0 = _mm_sub_epi16(_mm_add_epi16(vr, d0), s0);
+
+ // Calculate squared error and add adjacent values
+ const __m128i err0 = _mm_madd_epi16(e0, e0);
+
+ sum32 = _mm_add_epi32(sum32, err0);
+ }
+
+ const __m128i sum32l = _mm_cvtepu32_epi64(sum32);
+ sum64 = _mm_add_epi64(sum64, sum32l);
+ const __m128i sum32h = _mm_cvtepu32_epi64(_mm_srli_si128(sum32, 8));
+ sum64 = _mm_add_epi64(sum64, sum32h);
+
+ // Process remaining pixels in this row (modulo 8)
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq_on * (flt[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt += flt_stride;
+ }
+ } else { // Neither filter is enabled
+ for (i = 0; i < height; ++i) {
+ __m128i sum32 = _mm_setzero_si128();
+ for (j = 0; j <= width - 16; j += 16) {
+ // Load 2x8 u16 from source image
+ const __m128i s0 = xx_loadu_128(src + j);
+ const __m128i s1 = xx_loadu_128(src + j + 8);
+ // Load 2x8 u16 from corrupted image
+ const __m128i d0 = xx_loadu_128(dat + j);
+ const __m128i d1 = xx_loadu_128(dat + j + 8);
+
+ // Subtract source image from corrupted image
+ const __m128i diff0 = _mm_sub_epi16(d0, s0);
+ const __m128i diff1 = _mm_sub_epi16(d1, s1);
+
+ // Square error and add adjacent values
+ const __m128i err0 = _mm_madd_epi16(diff0, diff0);
+ const __m128i err1 = _mm_madd_epi16(diff1, diff1);
+
+ sum32 = _mm_add_epi32(sum32, err0);
+ sum32 = _mm_add_epi32(sum32, err1);
+ }
+
+ const __m128i sum32l = _mm_cvtepu32_epi64(sum32);
+ sum64 = _mm_add_epi64(sum64, sum32l);
+ const __m128i sum32h = _mm_cvtepu32_epi64(_mm_srli_si128(sum32, 8));
+ sum64 = _mm_add_epi64(sum64, sum32h);
+
+ // Process remaining pixels in this row (modulo 16)
+ for (k = j; k < width; ++k) {
+ const int32_t e = (int32_t)(dat[k]) - src[k];
+ err += e * e;
+ }
+ dat += dat_stride;
+ src += src_stride;
+ }
+ }
+
+ // Sum the two i64 values in sum64 into err
+ int64_t sum[2];
+ xx_storeu_128(sum, sum64);
+ err += sum[0] + sum[1];
+ return err;
+}
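
Both single-filter kernels above rely on a small algebraic rewrite so the
shifted corrupt pixel u = d << SGRPROJ_RST_BITS never has to be formed in
the vector lanes: the shift is folded into a second coefficient applied
directly to d. A scalar restatement (v_single is a hypothetical name):

    static inline int32_t v_single(int32_t flt, int32_t d, int32_t xq) {
      const int32_t xq_active = xq;
      const int32_t xq_inactive = -xq * (1 << SGRPROJ_RST_BITS);
      // flt * xq + d * (-xq << SGRPROJ_RST_BITS)
      //   == xq * (flt - (d << SGRPROJ_RST_BITS))
      return flt * xq_active + d * xq_inactive;
    }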
diff --git a/test/pickrst_test.cc b/test/pickrst_test.cc
index 040e8e8..68b6621 100644
--- a/test/pickrst_test.cc
+++ b/test/pickrst_test.cc
@@ -23,7 +23,7 @@
#define MAX_DATA_BLOCK 384
-namespace {
+namespace pickrst_test_lowbd {
static const int kIterations = 100;
typedef int64_t (*lowbd_pixel_proj_error_func)(
@@ -184,4 +184,169 @@
::testing::Values(av1_lowbd_pixel_proj_error_avx2));
#endif // HAVE_AVX2
-} // namespace
+} // namespace pickrst_test_lowbd
+
+namespace pickrst_test_highbd {
+static const int kIterations = 100;
+
+typedef int64_t (*highbd_pixel_proj_error_func)(
+ const uint8_t *src8, int width, int height, int src_stride,
+ const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
+ int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params);
+
+typedef libaom_test::FuncParam<highbd_pixel_proj_error_func> TestFuncs;
+
+////////////////////////////////////////////////////////////////////////////////
+// High bit-depth
+////////////////////////////////////////////////////////////////////////////////
+
+typedef ::testing::tuple<const highbd_pixel_proj_error_func>
+ PixelProjErrorTestParam;
+
+class PixelProjHighbdErrorTest
+ : public ::testing::TestWithParam<PixelProjErrorTestParam> {
+ public:
+ virtual void SetUp() {
+ target_func_ = GET_PARAM(0);
+ src_ =
+ (uint16_t *)aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * sizeof(*src_));
+ dgd_ =
+ (uint16_t *)aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * sizeof(*dgd_));
+ flt0_ =
+ (int32_t *)aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * sizeof(*flt0_));
+ flt1_ =
+ (int32_t *)aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * sizeof(*flt1_));
+ }
+ virtual void TearDown() {
+ aom_free(src_);
+ aom_free(dgd_);
+ aom_free(flt0_);
+ aom_free(flt1_);
+ }
+ void runPixelProjErrorTest(int32_t run_times);
+ void runPixelProjErrorTest_ExtremeValues();
+
+ private:
+ highbd_pixel_proj_error_func target_func_;
+ ACMRandom rng_;
+ uint16_t *src_;
+ uint16_t *dgd_;
+ int32_t *flt0_;
+ int32_t *flt1_;
+};
+
+void PixelProjHighbdErrorTest::runPixelProjErrorTest(int32_t run_times) {
+ int h_end = run_times != 1 ? 128 : (rng_.Rand16() % MAX_DATA_BLOCK) + 1;
+ int v_end = run_times != 1 ? 128 : (rng_.Rand16() % MAX_DATA_BLOCK) + 1;
+ const int dgd_stride = MAX_DATA_BLOCK;
+ const int src_stride = MAX_DATA_BLOCK;
+ const int flt0_stride = MAX_DATA_BLOCK;
+ const int flt1_stride = MAX_DATA_BLOCK;
+ sgr_params_type params;
+ int xq[2];
+ const int iters = run_times == 1 ? kIterations : 4;
+ for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) {
+ int64_t err_ref = 0, err_test = 1;
+ for (int i = 0; i < MAX_DATA_BLOCK * MAX_DATA_BLOCK; ++i) {
+ dgd_[i] = rng_.Rand16() % (1 << 12);
+ src_[i] = rng_.Rand16() % (1 << 12);
+ flt0_[i] = rng_.Rand15Signed();
+ flt1_[i] = rng_.Rand15Signed();
+ }
+ xq[0] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+ xq[1] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+ params.r[0] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter % 2);
+ params.r[1] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter / 2);
+ params.s[0] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter % 2);
+ params.s[1] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter / 2);
+ uint8_t *dgd8 = CONVERT_TO_BYTEPTR(dgd_);
+ uint8_t *src8 = CONVERT_TO_BYTEPTR(src_);
+
+ aom_usec_timer timer;
+ aom_usec_timer_start(&timer);
+ for (int i = 0; i < run_times; ++i) {
+ err_ref = av1_highbd_pixel_proj_error_c(
+ src8, h_end, v_end, src_stride, dgd8, dgd_stride, flt0_, flt0_stride,
+ flt1_, flt1_stride, xq, &params);
+ }
+ aom_usec_timer_mark(&timer);
+ const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));
+ aom_usec_timer_start(&timer);
+ for (int i = 0; i < run_times; ++i) {
+ err_test =
+ target_func_(src8, h_end, v_end, src_stride, dgd8, dgd_stride, flt0_,
+ flt0_stride, flt1_, flt1_stride, xq, &params);
+ }
+ aom_usec_timer_mark(&timer);
+ const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));
+ if (run_times > 10) {
+ printf("r0 %d r1 %d %3dx%-3d:%7.2f/%7.2fns (%3.2f)\n", params.r[0],
+ params.r[1], h_end, v_end, time1, time2, time1 / time2);
+ }
+ ASSERT_EQ(err_ref, err_test);
+ }
+}
+
+void PixelProjHighbdErrorTest::runPixelProjErrorTest_ExtremeValues() {
+ const int h_start = 0;
+ int h_end = 192;
+ const int v_start = 0;
+ int v_end = 192;
+ const int dgd_stride = MAX_DATA_BLOCK;
+ const int src_stride = MAX_DATA_BLOCK;
+ const int flt0_stride = MAX_DATA_BLOCK;
+ const int flt1_stride = MAX_DATA_BLOCK;
+ sgr_params_type params;
+ int xq[2];
+ const int iters = kIterations;
+ for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) {
+ int64_t err_ref = 0, err_test = 1;
+ for (int i = 0; i < MAX_DATA_BLOCK * MAX_DATA_BLOCK; ++i) {
+ dgd_[i] = 0;
+ src_[i] = (1 << 12) - 1;
+ flt0_[i] = rng_.Rand15Signed();
+ flt1_[i] = rng_.Rand15Signed();
+ }
+ xq[0] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+ xq[1] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+ params.r[0] = rng_.Rand8() % MAX_RADIUS;
+ params.r[1] = rng_.Rand8() % MAX_RADIUS;
+ params.s[0] = rng_.Rand8() % MAX_RADIUS;
+ params.s[1] = rng_.Rand8() % MAX_RADIUS;
+ uint8_t *dgd8 = CONVERT_TO_BYTEPTR(dgd_);
+ uint8_t *src8 = CONVERT_TO_BYTEPTR(src_);
+
+ err_ref = av1_highbd_pixel_proj_error_c(
+ src8, h_end - h_start, v_end - v_start, src_stride, dgd8, dgd_stride,
+ flt0_, flt0_stride, flt1_, flt1_stride, xq, &params);
+
+ err_test = target_func_(src8, h_end - h_start, v_end - v_start, src_stride,
+ dgd8, dgd_stride, flt0_, flt0_stride, flt1_,
+ flt1_stride, xq, &params);
+
+ ASSERT_EQ(err_ref, err_test);
+ }
+}
+
+TEST_P(PixelProjHighbdErrorTest, RandomValues) { runPixelProjErrorTest(1); }
+
+TEST_P(PixelProjHighbdErrorTest, ExtremeValues) {
+ runPixelProjErrorTest_ExtremeValues();
+}
+
+TEST_P(PixelProjHighbdErrorTest, DISABLED_Speed) {
+ runPixelProjErrorTest(200000);
+}
+
+#if HAVE_SSE4_1
+INSTANTIATE_TEST_CASE_P(SSE4_1, PixelProjHighbdErrorTest,
+ ::testing::Values(av1_highbd_pixel_proj_error_sse4_1));
+#endif // HAVE_SSE4_1
+
+#if HAVE_AVX2
+
+INSTANTIATE_TEST_CASE_P(AVX2, PixelProjHighbdErrorTest,
+ ::testing::Values(av1_highbd_pixel_proj_error_avx2));
+#endif // HAVE_AVX2
+
+} // namespace pickrst_test_highbd
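
To run the new tests locally (assuming the standard test_libaom binary
produced by the libaom build; the speed comparison is DISABLED_ by default
and must be requested explicitly):

    ./test_libaom --gtest_filter='*PixelProjHighbdErrorTest*'
    ./test_libaom --gtest_filter='*PixelProjHighbdErrorTest.DISABLED_Speed' \
        --gtest_also_run_disabled_tests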