Add dual and quad intrinsic support for CDEF MSE computation
This CL attempts to evaluate the aom_mse_wxh_16bit for two 8x8
and four 4x4 blocks at a time. It also avoids a few of the unpacks
and 32-bit register operations which can be handled within 16-bit
registers.
The overall encode time reduction for RT preset is listed below
Encode_time
cpu Testset Reduction(%)
7 rtc 0.765
8 rtc 0.803
Change-Id: I385c7ba00764575a620d563f0f6fed330ff5096d
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index aeaf9f1..d7ab4d9 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1319,6 +1319,9 @@
add_proto qw/uint64_t/, "aom_mse_wxh_16bit", "uint8_t *dst, int dstride,uint16_t *src, int sstride, int w, int h";
specialize qw/aom_mse_wxh_16bit sse2 avx2/;
+ add_proto qw/uint64_t/, "aom_mse_16xh_16bit", "uint8_t *dst, int dstride,uint16_t *src, int w, int h";
+ specialize qw/aom_mse_16xh_16bit avx2/;
+
foreach (@encoder_block_sizes) {
($w, $h) = @$_;
add_proto qw/unsigned int/, "aom_variance${w}x${h}", "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
diff --git a/aom_dsp/variance.c b/aom_dsp/variance.c
index d764160..a37f732 100644
--- a/aom_dsp/variance.c
+++ b/aom_dsp/variance.c
@@ -1240,6 +1240,20 @@
return sum;
}
+uint64_t aom_mse_16xh_16bit_c(uint8_t *dst, int dstride, uint16_t *src, int w,
+ int h) {
+ uint16_t *src_temp = src;
+ uint8_t *dst_temp = dst;
+ const int num_blks = 16 / w;
+ int64_t sum = 0;
+ for (int i = 0; i < num_blks; i++) {
+ sum += aom_mse_wxh_16bit_c(dst_temp, dstride, src_temp, w, w, h);
+ dst_temp += w;
+ src_temp += (w * h);
+ }
+ return sum;
+}
+
uint64_t aom_mse_wxh_16bit_highbd_c(uint16_t *dst, int dstride, uint16_t *src,
int sstride, int w, int h) {
uint64_t sum = 0;
diff --git a/aom_dsp/x86/variance_avx2.c b/aom_dsp/x86/variance_avx2.c
index a7203ec..d5eb253 100644
--- a/aom_dsp/x86/variance_avx2.c
+++ b/aom_dsp/x86/variance_avx2.c
@@ -535,7 +535,7 @@
__m128i dst0_4x8, dst1_4x8, dst2_4x8, dst3_4x8, dst_16x8;
__m128i src0_4x16, src1_4x16, src2_4x16, src3_4x16;
__m256i src0_8x16, src1_8x16, dst_16x16, src_16x16;
- __m256i res0_4x64, res1_4x64, res2_4x64, res3_4x64;
+ __m256i res0_4x64, res1_4x64;
__m256i sub_result;
const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
__m256i square_result = _mm256_broadcastsi128_si256(_mm_setzero_si128());
@@ -558,30 +558,121 @@
_mm256_castsi128_si256(_mm_unpacklo_epi64(src2_4x16, src3_4x16));
src_16x16 = _mm256_permute2x128_si256(src0_8x16, src1_8x16, 0x20);
+ // r15 r14 r13------------r1 r0 - 16 bit
sub_result = _mm256_abs_epi16(_mm256_sub_epi16(src_16x16, dst_16x16));
- src_16x16 = _mm256_unpacklo_epi16(sub_result, zeros);
- dst_16x16 = _mm256_unpackhi_epi16(sub_result, zeros);
+ // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit
+ src_16x16 = _mm256_madd_epi16(sub_result, sub_result);
- src_16x16 = _mm256_madd_epi16(src_16x16, src_16x16); // 32bit store
- dst_16x16 = _mm256_madd_epi16(dst_16x16, dst_16x16); // 32bit store
-
- res0_4x64 = _mm256_unpacklo_epi32(src_16x16, zeros);
- res1_4x64 = _mm256_unpackhi_epi32(src_16x16, zeros);
- res2_4x64 = _mm256_unpacklo_epi32(dst_16x16, zeros);
- res3_4x64 = _mm256_unpackhi_epi32(dst_16x16, zeros);
-
- square_result = _mm256_add_epi64(
- square_result,
- _mm256_add_epi64(
- _mm256_add_epi64(_mm256_add_epi64(res0_4x64, res1_4x64), res2_4x64),
- res3_4x64));
+ // accumulation of result
+ square_result = _mm256_add_epi32(square_result, src_16x16);
}
- const __m128i sum_2x64 =
- _mm_add_epi64(_mm256_castsi256_si128(square_result),
- _mm256_extracti128_si256(square_result, 1));
- const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
- xx_storel_64(&sum, sum_1x64);
+
+ // s5 s4 s1 s0 - 64bit
+ res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
+ // s7 s6 s3 s2 - 64bit
+ res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
+ // r3 r2 r1 r0 - 64bit
+ res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
+ // r1+r3 r2+r0 - 64bit
+ const __m128i sum_1x64 =
+ _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
+ _mm256_extracti128_si256(res0_4x64, 1));
+ xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
+ return sum;
+}
+
+// Compute mse of four consecutive 4x4 blocks.
+// In src buffer, each 4x4 block in a 32x32 filter block is stored sequentially.
+// Hence src_blk_stride equals the block area (w * h). Whereas dst buffer is a
+// frame buffer, thus dstride is a frame level stride.
+uint64_t aom_mse_4xh_quad_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
+ int src_blk_stride, int h) {
+ uint64_t sum = 0;
+ __m128i dst0_16x8, dst1_16x8, dst2_16x8, dst3_16x8;
+ __m256i dst0_16x16, dst1_16x16, dst2_16x16, dst3_16x16;
+ __m256i res0_4x64, res1_4x64;
+ __m256i sub_result_0, sub_result_1, sub_result_2, sub_result_3;
+ const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
+ __m256i square_result = zeros;
+ uint16_t *src_temp = src;
+
+ for (int i = 0; i < h; i += 4) {
+ dst0_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 0) * dstride]));
+ dst1_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 1) * dstride]));
+ dst2_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 2) * dstride]));
+ dst3_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 3) * dstride]));
+
+ // row0 of 1st,2nd, 3rd and 4th 4x4 blocks- d00 d10 d20 d30
+ dst0_16x16 = _mm256_cvtepu8_epi16(dst0_16x8);
+ // row1 of 1st,2nd, 3rd and 4th 4x4 blocks - d01 d11 d21 d31
+ dst1_16x16 = _mm256_cvtepu8_epi16(dst1_16x8);
+ // row2 of 1st,2nd, 3rd and 4th 4x4 blocks - d02 d12 d22 d32
+ dst2_16x16 = _mm256_cvtepu8_epi16(dst2_16x8);
+ // row3 of 1st,2nd, 3rd and 4th 4x4 blocks - d03 d13 d23 d33
+ dst3_16x16 = _mm256_cvtepu8_epi16(dst3_16x8);
+
+ // All rows of 1st 4x4 block - r00 r01 r02 r03
+ __m256i src0_16x16 = _mm256_loadu_si256((__m256i const *)(&src_temp[0]));
+ // All rows of 2nd 4x4 block - r10 r11 r12 r13
+ __m256i src1_16x16 =
+ _mm256_loadu_si256((__m256i const *)(&src_temp[src_blk_stride]));
+ // All rows of 3rd 4x4 block - r20 r21 r22 r23
+ __m256i src2_16x16 =
+ _mm256_loadu_si256((__m256i const *)(&src_temp[2 * src_blk_stride]));
+ // All rows of 4th 4x4 block - r30 r31 r32 r33
+ __m256i src3_16x16 =
+ _mm256_loadu_si256((__m256i const *)(&src_temp[3 * src_blk_stride]));
+
+ // r00 r10 r02 r12
+ __m256i tmp0_16x16 = _mm256_unpacklo_epi64(src0_16x16, src1_16x16);
+ // r01 r11 r03 r13
+ __m256i tmp1_16x16 = _mm256_unpackhi_epi64(src0_16x16, src1_16x16);
+ // r20 r30 r22 r32
+ __m256i tmp2_16x16 = _mm256_unpacklo_epi64(src2_16x16, src3_16x16);
+ // r21 r31 r23 r33
+ __m256i tmp3_16x16 = _mm256_unpackhi_epi64(src2_16x16, src3_16x16);
+
+ // r00 r10 r20 r30
+ src0_16x16 = _mm256_permute2f128_si256(tmp0_16x16, tmp2_16x16, 0x20);
+ // r01 r11 r21 r31
+ src1_16x16 = _mm256_permute2f128_si256(tmp1_16x16, tmp3_16x16, 0x20);
+ // r02 r12 r22 r32
+ src2_16x16 = _mm256_permute2f128_si256(tmp0_16x16, tmp2_16x16, 0x31);
+ // r03 r13 r23 r33
+ src3_16x16 = _mm256_permute2f128_si256(tmp1_16x16, tmp3_16x16, 0x31);
+
+ // r15 r14 r13------------r1 r0 - 16 bit
+ sub_result_0 = _mm256_abs_epi16(_mm256_sub_epi16(src0_16x16, dst0_16x16));
+ sub_result_1 = _mm256_abs_epi16(_mm256_sub_epi16(src1_16x16, dst1_16x16));
+ sub_result_2 = _mm256_abs_epi16(_mm256_sub_epi16(src2_16x16, dst2_16x16));
+ sub_result_3 = _mm256_abs_epi16(_mm256_sub_epi16(src3_16x16, dst3_16x16));
+
+ // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit
+ src0_16x16 = _mm256_madd_epi16(sub_result_0, sub_result_0);
+ src1_16x16 = _mm256_madd_epi16(sub_result_1, sub_result_1);
+ src2_16x16 = _mm256_madd_epi16(sub_result_2, sub_result_2);
+ src3_16x16 = _mm256_madd_epi16(sub_result_3, sub_result_3);
+
+ // accumulation of result
+ src0_16x16 = _mm256_add_epi32(src0_16x16, src1_16x16);
+ src2_16x16 = _mm256_add_epi32(src2_16x16, src3_16x16);
+ const __m256i square_result_0 = _mm256_add_epi32(src0_16x16, src2_16x16);
+ square_result = _mm256_add_epi32(square_result, square_result_0);
+ src_temp += 16;
+ }
+
+ // s5 s4 s1 s0 - 64bit
+ res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
+ // s7 s6 s3 s2 - 64bit
+ res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
+ // r3 r2 r1 r0 - 64bit
+ res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
+ // r1+r3 r2+r0 - 64bit
+ const __m128i sum_1x64 =
+ _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
+ _mm256_extracti128_si256(res0_4x64, 1));
+ xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
return sum;
}
@@ -590,7 +681,7 @@
uint64_t sum = 0;
__m128i dst0_8x8, dst1_8x8, dst3_16x8;
__m256i src0_8x16, src1_8x16, src_16x16, dst_16x16;
- __m256i res0_4x64, res1_4x64, res2_4x64, res3_4x64;
+ __m256i res0_4x64, res1_4x64;
__m256i sub_result;
const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
__m256i square_result = _mm256_broadcastsi128_si256(_mm_setzero_si128());
@@ -607,38 +698,98 @@
_mm_loadu_si128((__m128i *)&src[(i + 1) * sstride]));
src_16x16 = _mm256_permute2x128_si256(src0_8x16, src1_8x16, 0x20);
+ // r15 r14 r13 - - - r1 r0 - 16 bit
sub_result = _mm256_abs_epi16(_mm256_sub_epi16(src_16x16, dst_16x16));
- src_16x16 = _mm256_unpacklo_epi16(sub_result, zeros);
- dst_16x16 = _mm256_unpackhi_epi16(sub_result, zeros);
+ // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit
+ src_16x16 = _mm256_madd_epi16(sub_result, sub_result);
- src_16x16 = _mm256_madd_epi16(src_16x16, src_16x16);
- dst_16x16 = _mm256_madd_epi16(dst_16x16, dst_16x16);
-
- res0_4x64 = _mm256_unpacklo_epi32(src_16x16, zeros);
- res1_4x64 = _mm256_unpackhi_epi32(src_16x16, zeros);
- res2_4x64 = _mm256_unpacklo_epi32(dst_16x16, zeros);
- res3_4x64 = _mm256_unpackhi_epi32(dst_16x16, zeros);
-
- square_result = _mm256_add_epi64(
- square_result,
- _mm256_add_epi64(
- _mm256_add_epi64(_mm256_add_epi64(res0_4x64, res1_4x64), res2_4x64),
- res3_4x64));
+ // accumulation of result
+ square_result = _mm256_add_epi32(square_result, src_16x16);
}
- const __m128i sum_2x64 =
- _mm_add_epi64(_mm256_castsi256_si128(square_result),
- _mm256_extracti128_si256(square_result, 1));
- const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
- xx_storel_64(&sum, sum_1x64);
+ // s5 s4 s1 s0 - 64bit
+ res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
+ // s7 s6 s3 s2 - 64bit
+ res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
+ // r3 r2 r1 r0 - 64bit
+ res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
+ // r1+r3 r2+r0 - 64bit
+ const __m128i sum_1x64 =
+ _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
+ _mm256_extracti128_si256(res0_4x64, 1));
+ xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
+ return sum;
+}
+
+// Compute mse of two consecutive 8x8 blocks.
+// In src buffer, each 8x8 block in a 64x64 filter block is stored sequentially.
+// Hence src_blk_stride equals the block area (w * h). Whereas dst buffer is a
+// frame buffer, thus dstride is a frame level stride.
+uint64_t aom_mse_8xh_dual_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
+ int src_blk_stride, int h) {
+ uint64_t sum = 0;
+ __m128i dst0_16x8, dst1_16x8;
+ __m256i dst0_16x16, dst1_16x16;
+ __m256i res0_4x64, res1_4x64;
+ __m256i sub_result_0, sub_result_1;
+ const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
+ __m256i square_result = zeros;
+ uint16_t *src_temp = src;
+
+ for (int i = 0; i < h; i += 2) {
+ dst0_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 0) * dstride]));
+ dst1_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 1) * dstride]));
+
+ // row0 of 1st and 2nd 8x8 block - d00 d10
+ dst0_16x16 = _mm256_cvtepu8_epi16(dst0_16x8);
+ // row1 of 1st and 2nd 8x8 block - d01 d11
+ dst1_16x16 = _mm256_cvtepu8_epi16(dst1_16x8);
+
+ // 2 rows of 1st 8x8 block - r00 r01
+ __m256i src0_16x16 = _mm256_loadu_si256((__m256i const *)(&src_temp[0]));
+ // 2 rows of 2nd 8x8 block - r10 r11
+ __m256i src1_16x16 =
+ _mm256_loadu_si256((__m256i const *)(&src_temp[src_blk_stride]));
+ // r00 r10 - 128bit
+ __m256i tmp0_16x16 =
+ _mm256_permute2f128_si256(src0_16x16, src1_16x16, 0x20);
+ // r01 r11 - 128bit
+ __m256i tmp1_16x16 =
+ _mm256_permute2f128_si256(src0_16x16, src1_16x16, 0x31);
+
+ // r15 r14 r13------------r1 r0 - 16 bit
+ sub_result_0 = _mm256_abs_epi16(_mm256_sub_epi16(tmp0_16x16, dst0_16x16));
+ sub_result_1 = _mm256_abs_epi16(_mm256_sub_epi16(tmp1_16x16, dst1_16x16));
+
+ // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit each
+ src0_16x16 = _mm256_madd_epi16(sub_result_0, sub_result_0);
+ src1_16x16 = _mm256_madd_epi16(sub_result_1, sub_result_1);
+
+ // accumulation of result
+ src0_16x16 = _mm256_add_epi32(src0_16x16, src1_16x16);
+ square_result = _mm256_add_epi32(square_result, src0_16x16);
+ src_temp += 16;
+ }
+
+ // s5 s4 s1 s0 - 64bit
+ res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
+ // s7 s6 s3 s2 - 64bit
+ res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
+ // r3 r2 r1 r0 - 64bit
+ res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
+ // r1+r3 r2+r0 - 64bit
+ const __m128i sum_1x64 =
+ _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
+ _mm256_extracti128_si256(res0_4x64, 1));
+ xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
return sum;
}
uint64_t aom_mse_wxh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
int sstride, int w, int h) {
assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
- "w=8/4 and h=8/4 must satisfy");
+ "w=8/4 and h=8/4 must be satisfied");
switch (w) {
case 4: return aom_mse_4xh_16bit_avx2(dst, dstride, src, sstride, h);
case 8: return aom_mse_8xh_16bit_avx2(dst, dstride, src, sstride, h);
@@ -646,6 +797,21 @@
}
}
+// Computes mse of two 8x8 or four 4x4 consecutive blocks. Luma plane uses 8x8
+// block and Chroma uses 4x4 block. In src buffer, each block in a filter block
+// is stored sequentially. Hence src_blk_stride equals the block area (w * h).
+// Whereas dst buffer is a frame buffer, thus dstride is a frame level stride.
+uint64_t aom_mse_16xh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
+ int w, int h) {
+ assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
+ "w=8/4 and h=8/4 must be satisfied");
+ switch (w) {
+ case 4: return aom_mse_4xh_quad_16bit_avx2(dst, dstride, src, w * h, h);
+ case 8: return aom_mse_8xh_dual_16bit_avx2(dst, dstride, src, w * h, h);
+ default: assert(0 && "unsupported width"); return -1;
+ }
+}
+
static INLINE void sum_final_256bit_avx2(__m256i sum_8x16[2], int *const sum) {
const __m256i sum_result_0 = _mm256_hadd_epi16(sum_8x16[0], sum_8x16[1]);
const __m256i sum_result_1 =
diff --git a/av1/encoder/pickcdef.c b/av1/encoder/pickcdef.c
index 3659650..ece0ba3 100644
--- a/av1/encoder/pickcdef.c
+++ b/av1/encoder/pickcdef.c
@@ -267,6 +267,22 @@
return sum >> 2 * coeff_shift;
}
#endif
+
+// Checks whether dual or quad block processing is applicable for block widths
+// 8 and 4 respectively.
+static INLINE int is_dual_or_quad_applicable(cdef_list *dlist, int width,
+ int cdef_count, int bi, int iter) {
+ assert(width == 8 || width == 4);
+ const int blk_offset = (width == 8) ? 1 : 3;
+ if ((iter + blk_offset) >= cdef_count) return 0;
+
+ if (dlist[bi].by == dlist[bi + blk_offset].by &&
+ dlist[bi].bx + blk_offset == dlist[bi + blk_offset].bx)
+ return 1;
+
+ return 0;
+}
+
static uint64_t compute_cdef_dist(void *dst, int dstride, uint16_t *src,
cdef_list *dlist, int cdef_count,
BLOCK_SIZE bsize, int coeff_shift, int row,
@@ -275,18 +291,34 @@
bsize == BLOCK_8X8);
uint64_t sum = 0;
int bi, bx, by;
+ int iter = 0;
+ int inc = 1;
uint8_t *dst8 = (uint8_t *)dst;
uint8_t *dst_buff = &dst8[row * dstride + col];
int src_stride, width, height, width_log2, height_log2;
init_src_params(&src_stride, &width, &height, &width_log2, &height_log2,
bsize);
- for (bi = 0; bi < cdef_count; bi++) {
+
+ const int num_blks = 16 / width;
+ for (bi = 0; bi < cdef_count; bi += inc) {
by = dlist[bi].by;
bx = dlist[bi].bx;
- sum += aom_mse_wxh_16bit(
- &dst_buff[(by << height_log2) * dstride + (bx << width_log2)], dstride,
- &src[bi << (height_log2 + width_log2)], src_stride, width, height);
+ uint16_t *src_tmp = &src[bi << (height_log2 + width_log2)];
+ uint8_t *dst_tmp =
+ &dst_buff[(by << height_log2) * dstride + (bx << width_log2)];
+
+ if (is_dual_or_quad_applicable(dlist, width, cdef_count, bi, iter)) {
+ sum += aom_mse_16xh_16bit(dst_tmp, dstride, src_tmp, width, height);
+ iter += num_blks;
+ inc = num_blks;
+ } else {
+ sum += aom_mse_wxh_16bit(dst_tmp, dstride, src_tmp, src_stride, width,
+ height);
+ iter += 1;
+ inc = 1;
+ }
}
+
return sum >> 2 * coeff_shift;
}
diff --git a/test/variance_test.cc b/test/variance_test.cc
index 8e6abf6..6c0180f 100644
--- a/test/variance_test.cc
+++ b/test/variance_test.cc
@@ -26,11 +26,14 @@
#include "aom_mem/aom_mem.h"
#include "aom_ports/aom_timer.h"
#include "aom_ports/mem.h"
+#include "av1/common/cdef_block.h"
namespace {
typedef uint64_t (*MseWxH16bitFunc)(uint8_t *dst, int dstride, uint16_t *src,
int sstride, int w, int h);
+typedef uint64_t (*Mse16xH16bitFunc)(uint8_t *dst, int dstride, uint16_t *src,
+ int w, int h);
typedef unsigned int (*VarianceMxNFunc)(const uint8_t *a, int a_stride,
const uint8_t *b, int b_stride,
unsigned int *sse);
@@ -513,6 +516,139 @@
}
}
+template <typename FunctionType>
+class Mse16xHTestClass
+ : public ::testing::TestWithParam<TestParams<FunctionType> > {
+ public:
+  // Buffer sized for the worst case of two 8x8 or four 4x4 blocks, i.e. a
+  // maximum width of 16 and a maximum height of 8.
+ int mem_size = 16 * 8;
+ virtual void SetUp() {
+ params_ = this->GetParam();
+ rnd_.Reset(ACMRandom::DeterministicSeed());
+ src_ = reinterpret_cast<uint16_t *>(
+ aom_memalign(16, mem_size * sizeof(*src_)));
+ dst_ =
+ reinterpret_cast<uint8_t *>(aom_memalign(16, mem_size * sizeof(*dst_)));
+ ASSERT_NE(src_, nullptr);
+ ASSERT_NE(dst_, nullptr);
+ }
+
+ virtual void TearDown() {
+ aom_free(src_);
+ aom_free(dst_);
+ src_ = nullptr;
+ dst_ = nullptr;
+ }
+
+ uint8_t RandBool() {
+ const uint32_t value = rnd_.Rand8();
+ return (value & 0x1);
+ }
+
+ protected:
+ void RefMatchExtremeTestMse();
+ void RefMatchTestMse();
+ void SpeedTest();
+
+ protected:
+ ACMRandom rnd_;
+ uint8_t *dst_;
+ uint16_t *src_;
+ TestParams<FunctionType> params_;
+
+ // some relay helpers
+ int width() const { return params_.width; }
+ int height() const { return params_.height; }
+ int d_stride() const { return params_.width; }
+};
+
+template <typename Mse16xHFunctionType>
+void Mse16xHTestClass<Mse16xHFunctionType>::SpeedTest() {
+ aom_usec_timer ref_timer, test_timer;
+ double elapsed_time_c = 0.0;
+ double elapsed_time_simd = 0.0;
+ const int loop_count = 10000000;
+ const int w = width();
+ const int h = height();
+ const int dstride = d_stride();
+
+ for (int k = 0; k < mem_size; ++k) {
+ dst_[k] = rnd_.Rand8();
+    // Right shift by 6 is done so that more inputs fall in the range [0, 255]
+    // than get replaced by CDEF_VERY_LARGE
+ int rnd_i10 = rnd_.Rand16() >> 6;
+ src_[k] = (rnd_i10 < 256) ? rnd_i10 : CDEF_VERY_LARGE;
+ }
+
+ aom_usec_timer_start(&ref_timer);
+ for (int i = 0; i < loop_count; i++) {
+ aom_mse_16xh_16bit_c(dst_, dstride, src_, w, h);
+ }
+ aom_usec_timer_mark(&ref_timer);
+ elapsed_time_c = static_cast<double>(aom_usec_timer_elapsed(&ref_timer));
+
+ aom_usec_timer_start(&test_timer);
+ for (int i = 0; i < loop_count; i++) {
+ params_.func(dst_, dstride, src_, w, h);
+ }
+ aom_usec_timer_mark(&test_timer);
+ elapsed_time_simd = static_cast<double>(aom_usec_timer_elapsed(&test_timer));
+
+ printf("%dx%d\tc_time=%lf \t simd_time=%lf \t gain=%.31f\n", width(),
+ height(), elapsed_time_c, elapsed_time_simd,
+ (elapsed_time_c / elapsed_time_simd));
+}
+
+template <typename Mse16xHFunctionType>
+void Mse16xHTestClass<Mse16xHFunctionType>::RefMatchTestMse() {
+ uint64_t mse_ref = 0;
+ uint64_t mse_mod = 0;
+ const int w = width();
+ const int h = height();
+ const int dstride = d_stride();
+
+ for (int i = 0; i < 10; i++) {
+ for (int k = 0; k < mem_size; ++k) {
+ dst_[k] = rnd_.Rand8();
+      // Right shift by 6 is done so that more inputs fall in the range
+      // [0, 255] than get replaced by CDEF_VERY_LARGE
+ int rnd_i10 = rnd_.Rand16() >> 6;
+ src_[k] = (rnd_i10 < 256) ? rnd_i10 : CDEF_VERY_LARGE;
+ }
+
+ API_REGISTER_STATE_CHECK(
+ mse_ref = aom_mse_16xh_16bit_c(dst_, dstride, src_, w, h));
+ API_REGISTER_STATE_CHECK(mse_mod = params_.func(dst_, dstride, src_, w, h));
+ EXPECT_EQ(mse_ref, mse_mod)
+ << "ref mse: " << mse_ref << " mod mse: " << mse_mod;
+ }
+}
+
+template <typename Mse16xHFunctionType>
+void Mse16xHTestClass<Mse16xHFunctionType>::RefMatchExtremeTestMse() {
+ uint64_t mse_ref = 0;
+ uint64_t mse_mod = 0;
+ const int w = width();
+ const int h = height();
+ const int dstride = d_stride();
+ const int iter = 10;
+
+ // Fill the buffers with extreme values
+ for (int i = 0; i < iter; i++) {
+ for (int k = 0; k < mem_size; ++k) {
+ dst_[k] = static_cast<uint8_t>(RandBool() ? 0 : 255);
+ src_[k] = static_cast<uint16_t>(RandBool() ? 0 : CDEF_VERY_LARGE);
+ }
+
+ API_REGISTER_STATE_CHECK(
+ mse_ref = aom_mse_16xh_16bit_c(dst_, dstride, src_, w, h));
+ API_REGISTER_STATE_CHECK(mse_mod = params_.func(dst_, dstride, src_, w, h));
+ EXPECT_EQ(mse_ref, mse_mod)
+ << "ref mse: " << mse_ref << " mod mse: " << mse_mod;
+ }
+}
+
// Main class for testing a function type
template <typename FunctionType>
class MainTestClass
@@ -1327,6 +1463,7 @@
#endif // !CONFIG_REALTIME_ONLY
typedef MseWxHTestClass<MseWxH16bitFunc> MseWxHTest;
+typedef Mse16xHTestClass<Mse16xH16bitFunc> Mse16xHTest;
typedef MainTestClass<Get4x4SseFunc> AvxSseTest;
typedef MainTestClass<VarianceMxNFunc> AvxMseTest;
typedef MainTestClass<VarianceMxNFunc> AvxVarianceTest;
@@ -1339,11 +1476,15 @@
typedef ObmcVarianceTest<ObmcSubpelVarFunc> AvxObmcSubpelVarianceTest;
#endif
typedef TestParams<MseWxH16bitFunc> MseWxHParams;
+typedef TestParams<Mse16xH16bitFunc> Mse16xHParams;
TEST_P(AvxSseTest, RefSse) { RefTestSse(); }
TEST_P(AvxSseTest, MaxSse) { MaxTestSse(); }
TEST_P(MseWxHTest, RefMse) { RefMatchTestMse(); }
TEST_P(MseWxHTest, DISABLED_SpeedMse) { SpeedTest(); }
+TEST_P(Mse16xHTest, RefMse) { RefMatchTestMse(); }
+TEST_P(Mse16xHTest, RefMseExtreme) { RefMatchExtremeTestMse(); }
+TEST_P(Mse16xHTest, DISABLED_SpeedMse) { SpeedTest(); }
TEST_P(AvxMseTest, RefMse) { RefTestMse(); }
TEST_P(AvxMseTest, MaxMse) { MaxTestMse(); }
TEST_P(AvxVarianceTest, Zero) { ZeroTest(); }
@@ -1375,6 +1516,13 @@
MseWxHParams(2, 3, &aom_mse_wxh_16bit_c, 8),
MseWxHParams(2, 2, &aom_mse_wxh_16bit_c, 8)));
+INSTANTIATE_TEST_SUITE_P(
+ C, Mse16xHTest,
+ ::testing::Values(Mse16xHParams(3, 3, &aom_mse_16xh_16bit_c, 8),
+ Mse16xHParams(3, 2, &aom_mse_16xh_16bit_c, 8),
+ Mse16xHParams(2, 3, &aom_mse_16xh_16bit_c, 8),
+ Mse16xHParams(2, 2, &aom_mse_16xh_16bit_c, 8)));
+
INSTANTIATE_TEST_SUITE_P(C, SumOfSquaresTest,
::testing::Values(aom_get_mb_ss_c));
@@ -2740,6 +2888,13 @@
MseWxHParams(2, 3, &aom_mse_wxh_16bit_avx2, 8),
MseWxHParams(2, 2, &aom_mse_wxh_16bit_avx2, 8)));
+INSTANTIATE_TEST_SUITE_P(
+ AVX2, Mse16xHTest,
+ ::testing::Values(Mse16xHParams(3, 3, &aom_mse_16xh_16bit_avx2, 8),
+ Mse16xHParams(3, 2, &aom_mse_16xh_16bit_avx2, 8),
+ Mse16xHParams(2, 3, &aom_mse_16xh_16bit_avx2, 8),
+ Mse16xHParams(2, 2, &aom_mse_16xh_16bit_avx2, 8)));
+
INSTANTIATE_TEST_SUITE_P(AVX2, AvxMseTest,
::testing::Values(MseParams(4, 4,
&aom_mse16x16_avx2)));