Add SSE2 variant for mse_wxh_16bit_highbd in cdef
Added an SSE2 variant of the mse_wxh_16bit_highbd function
and a corresponding unit test (MseHBDWxHTest).
Module level gains:
BLOCKSIZE Gain w.r.t. C
8x8 3.3x
8x4 3.0x
4x8 3.3x
4x4 3.2x
Change-Id: I9ad6851588e212a5417a0ab95d8f31e8656df1c1
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index ae25b00..2cef014 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1132,9 +1132,6 @@
add_proto qw/uint64_t/, "aom_mse_wxh_16bit", "uint8_t *dst, int dstride,uint16_t *src, int sstride, int w, int h";
specialize qw/aom_mse_wxh_16bit sse2 avx2/;
- add_proto qw/uint64_t/, "aom_mse_wxh_16bit_highbd", "uint16_t *dst, int dstride,uint16_t *src, int sstride, int w, int h";
- specialize qw/aom_mse_wxh_16bit_highbd avx2/;
-
foreach (@block_sizes) {
($w, $h) = @$_;
add_proto qw/unsigned int/, "aom_variance${w}x${h}", "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
@@ -1528,6 +1525,9 @@
add_proto qw/void aom_highbd_dist_wtd_comp_avg_pred/, "uint8_t *comp_pred8, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const DIST_WTD_COMP_PARAMS *jcp_param";
specialize qw/aom_highbd_dist_wtd_comp_avg_pred sse2/;
+
+ add_proto qw/uint64_t/, "aom_mse_wxh_16bit_highbd", "uint16_t *dst, int dstride,uint16_t *src, int sstride, int w, int h";
+ specialize qw/aom_mse_wxh_16bit_highbd sse2 avx2/;
}
#
# Subpixel Variance
diff --git a/aom_dsp/x86/highbd_variance_sse2.c b/aom_dsp/x86/highbd_variance_sse2.c
index b7d15f9..d1bd7d4 100644
--- a/aom_dsp/x86/highbd_variance_sse2.c
+++ b/aom_dsp/x86/highbd_variance_sse2.c
@@ -840,3 +840,100 @@
pred += 8;
}
}
+
+// Sum of squared differences between a 4-pixel-wide block of 16-bit dst
+// samples and 16-bit src samples over h rows. Processes two 4-pixel rows per
+// iteration (packed into one 128-bit register), so h must be even; the
+// dispatcher only passes h = 4 or 8.
+// NOTE(review): the differences are zero-extended before _mm_madd_epi16, but
+// madd still treats each 32-bit lane as a signed 16-bit pair {diff, 0}, so the
+// lane result is diff*diff + 0*0 — correct as long as each per-pixel
+// difference fits in int16 (true for bit depths up to 12; TODO confirm no
+// caller exceeds that).
+uint64_t aom_mse_4xh_16bit_highbd_sse2(uint16_t *dst, int dstride,
+ uint16_t *src, int sstride, int h) {
+ uint64_t sum = 0;
+ __m128i reg0_4x16, reg1_4x16;
+ __m128i src_8x16;
+ __m128i dst_8x16;
+ __m128i res0_4x32, res1_4x32, res0_4x64, res1_4x64, res2_4x64, res3_4x64;
+ __m128i sub_result_8x16;
+ const __m128i zeros = _mm_setzero_si128();
+ __m128i square_result = _mm_setzero_si128();
+ for (int i = 0; i < h; i += 2) {
+ // Load two 4-sample rows (64 bits each) and combine into one register.
+ reg0_4x16 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 0) * dstride]));
+ reg1_4x16 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 1) * dstride]));
+ dst_8x16 = _mm_unpacklo_epi64(reg0_4x16, reg1_4x16);
+
+ reg0_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 0) * sstride]));
+ reg1_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 1) * sstride]));
+ src_8x16 = _mm_unpacklo_epi64(reg0_4x16, reg1_4x16);
+
+ sub_result_8x16 = _mm_sub_epi16(src_8x16, dst_8x16);
+
+ // Widen the eight 16-bit differences to 32-bit lanes (low 16 bits hold the
+ // signed difference, high 16 bits are zero).
+ res0_4x32 = _mm_unpacklo_epi16(sub_result_8x16, zeros);
+ res1_4x32 = _mm_unpackhi_epi16(sub_result_8x16, zeros);
+
+ // Per 32-bit lane: diff*diff + 0*0, i.e. the squared difference.
+ res0_4x32 = _mm_madd_epi16(res0_4x32, res0_4x32);
+ res1_4x32 = _mm_madd_epi16(res1_4x32, res1_4x32);
+
+ // Zero-extend the 32-bit squares to 64-bit lanes before accumulating so
+ // the running sum cannot overflow 32 bits.
+ res0_4x64 = _mm_unpacklo_epi32(res0_4x32, zeros);
+ res1_4x64 = _mm_unpackhi_epi32(res0_4x32, zeros);
+ res2_4x64 = _mm_unpacklo_epi32(res1_4x32, zeros);
+ res3_4x64 = _mm_unpackhi_epi32(res1_4x32, zeros);
+
+ square_result = _mm_add_epi64(
+ square_result,
+ _mm_add_epi64(
+ _mm_add_epi64(_mm_add_epi64(res0_4x64, res1_4x64), res2_4x64),
+ res3_4x64));
+ }
+
+ // Horizontal add of the two 64-bit halves of the accumulator.
+ const __m128i sum_1x64 =
+ _mm_add_epi64(square_result, _mm_srli_si128(square_result, 8));
+ xx_storel_64(&sum, sum_1x64);
+ return sum;
+}
+
+// Sum of squared differences between an 8-pixel-wide block of 16-bit dst
+// samples and 16-bit src samples over h rows; one full row per iteration.
+// NOTE(review): same squaring trick as the 4xh variant — zero-extended
+// differences through _mm_madd_epi16 yield diff*diff per 32-bit lane, valid
+// while each difference fits in int16 (bit depths up to 12; TODO confirm).
+uint64_t aom_mse_8xh_16bit_highbd_sse2(uint16_t *dst, int dstride,
+ uint16_t *src, int sstride, int h) {
+ uint64_t sum = 0;
+ __m128i src_8x16;
+ __m128i dst_8x16;
+ __m128i res0_4x32, res1_4x32, res0_4x64, res1_4x64, res2_4x64, res3_4x64;
+ __m128i sub_result_8x16;
+ const __m128i zeros = _mm_setzero_si128();
+ __m128i square_result = _mm_setzero_si128();
+
+ for (int i = 0; i < h; i++) {
+ // Unaligned loads: row strides are caller-controlled, so no alignment is
+ // assumed.
+ dst_8x16 = _mm_loadu_si128((__m128i *)&dst[i * dstride]);
+ src_8x16 = _mm_loadu_si128((__m128i *)&src[i * sstride]);
+
+ sub_result_8x16 = _mm_sub_epi16(src_8x16, dst_8x16);
+
+ // Widen the eight 16-bit differences to 32-bit lanes ({diff, 0}).
+ res0_4x32 = _mm_unpacklo_epi16(sub_result_8x16, zeros);
+ res1_4x32 = _mm_unpackhi_epi16(sub_result_8x16, zeros);
+
+ // Per 32-bit lane: diff*diff + 0*0.
+ res0_4x32 = _mm_madd_epi16(res0_4x32, res0_4x32);
+ res1_4x32 = _mm_madd_epi16(res1_4x32, res1_4x32);
+
+ // Accumulate in 64-bit lanes to avoid 32-bit overflow of the running sum.
+ res0_4x64 = _mm_unpacklo_epi32(res0_4x32, zeros);
+ res1_4x64 = _mm_unpackhi_epi32(res0_4x32, zeros);
+ res2_4x64 = _mm_unpacklo_epi32(res1_4x32, zeros);
+ res3_4x64 = _mm_unpackhi_epi32(res1_4x32, zeros);
+
+ square_result = _mm_add_epi64(
+ square_result,
+ _mm_add_epi64(
+ _mm_add_epi64(_mm_add_epi64(res0_4x64, res1_4x64), res2_4x64),
+ res3_4x64));
+ }
+
+ // Horizontal add of the two 64-bit halves of the accumulator.
+ const __m128i sum_1x64 =
+ _mm_add_epi64(square_result, _mm_srli_si128(square_result, 8));
+ xx_storel_64(&sum, sum_1x64);
+ return sum;
+}
+
+// Public entry point matching the aom_mse_wxh_16bit_highbd rtcd prototype:
+// dispatches to the width-specialized kernels above. Only w ∈ {4, 8} and
+// h ∈ {4, 8} are supported (CDEF block sizes).
+uint64_t aom_mse_wxh_16bit_highbd_sse2(uint16_t *dst, int dstride,
+ uint16_t *src, int sstride, int w,
+ int h) {
+ assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
+ "w=8/4 and h=8/4 must satisfy");
+ switch (w) {
+ case 4: return aom_mse_4xh_16bit_highbd_sse2(dst, dstride, src, sstride, h);
+ case 8: return aom_mse_8xh_16bit_highbd_sse2(dst, dstride, src, sstride, h);
+ // NOTE(review): with NDEBUG the assert vanishes and -1 converts to
+ // UINT64_MAX (uint64_t return) — intentional sentinel, never a valid MSE.
+ default: assert(0 && "unsupported width"); return -1;
+ }
+}
diff --git a/av1/encoder/pickcdef.c b/av1/encoder/pickcdef.c
index 1d9a315..ac293bc 100644
--- a/av1/encoder/pickcdef.c
+++ b/av1/encoder/pickcdef.c
@@ -226,6 +226,7 @@
BLOCK_SIZE bsize, int coeff_shift,
int row, int col);
+#if CONFIG_AV1_HIGHBITDEPTH
static void copy_sb16_16_highbd(uint16_t *dst, int dstride, const void *src,
int src_voffset, int src_hoffset, int sstride,
int vsize, int hsize) {
@@ -235,6 +236,7 @@
for (r = 0; r < vsize; r++)
memcpy(dst + r * dstride, base + r * sstride, hsize * sizeof(*base));
}
+#endif
static void copy_sb16_16(uint16_t *dst, int dstride, const void *src,
int src_voffset, int src_hoffset, int sstride,
@@ -256,7 +258,7 @@
*width_log2 = MI_SIZE_LOG2 + mi_size_wide_log2[bsize];
*height_log2 = MI_SIZE_LOG2 + mi_size_wide_log2[bsize];
}
-
+#if CONFIG_AV1_HIGHBITDEPTH
/* Compute MSE only on the blocks we filtered. */
static uint64_t compute_cdef_dist_highbd(void *dst, int dstride, uint16_t *src,
cdef_list *dlist, int cdef_count,
@@ -280,7 +282,7 @@
}
return sum >> 2 * coeff_shift;
}
-
+#endif
static uint64_t compute_cdef_dist(void *dst, int dstride, uint16_t *src,
cdef_list *dlist, int cdef_count,
BLOCK_SIZE bsize, int coeff_shift, int row,
@@ -422,7 +424,7 @@
copy_fn_t copy_fn;
compute_cdef_dist_t compute_cdef_dist_fn;
-
+#if CONFIG_AV1_HIGHBITDEPTH
if (cm->seq_params.use_highbitdepth) {
copy_fn = copy_sb16_16_highbd;
compute_cdef_dist_fn = compute_cdef_dist_highbd;
@@ -430,6 +432,10 @@
copy_fn = copy_sb16_16;
compute_cdef_dist_fn = compute_cdef_dist;
}
+#else
+ copy_fn = copy_sb16_16;
+ compute_cdef_dist_fn = compute_cdef_dist;
+#endif
DECLARE_ALIGNED(32, uint16_t, inbuf[CDEF_INBUF_SIZE]);
uint16_t *const in = inbuf + CDEF_VBORDER * CDEF_BSTRIDE + CDEF_HBORDER;
diff --git a/test/variance_test.cc b/test/variance_test.cc
index e3c060e..72fc754 100644
--- a/test/variance_test.cc
+++ b/test/variance_test.cc
@@ -2002,6 +2002,15 @@
0)));
#if CONFIG_AV1_HIGHBITDEPTH
+#if HAVE_SSE2
+INSTANTIATE_TEST_SUITE_P(
+ SSE2, MseHBDWxHTest,
+ ::testing::Values(MseHBDWxHParams(3, 3, &aom_mse_wxh_16bit_highbd_sse2, 10),
+ MseHBDWxHParams(3, 2, &aom_mse_wxh_16bit_highbd_sse2, 10),
+ MseHBDWxHParams(2, 3, &aom_mse_wxh_16bit_highbd_sse2, 10),
+ MseHBDWxHParams(2, 2, &aom_mse_wxh_16bit_highbd_sse2,
+ 10)));
+#endif // HAVE_SSE2
#if HAVE_SSE4_1
INSTANTIATE_TEST_SUITE_P(
SSE4_1, AvxSubpelVarianceTest,