Support widths > 32 in aom_comp_mask_pred()
for AVX2 and SSSE3.
Call aom_comp_mask_upsampled_pred() instead of
aom_comp_mask_upsampled_pred_c() in mcomp.c.
Update the unit tests for widths > 32.
BUG=aomedia:2848
Change-Id: Ib5829dae3818b28a6ab2644887fa18bae78d9d44
diff --git a/aom_dsp/x86/masked_variance_intrin_ssse3.c b/aom_dsp/x86/masked_variance_intrin_ssse3.c
index fa93f0d..c859628 100644
--- a/aom_dsp/x86/masked_variance_intrin_ssse3.c
+++ b/aom_dsp/x86/masked_variance_intrin_ssse3.c
@@ -1052,12 +1052,15 @@
mask += (mask_stride << 1);
i += 2;
} while (i < height);
- } else { // width == 32
- assert(width == 32);
+ } else {
do {
- comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
- comp_mask_pred_16_ssse3(src0 + 16, src1 + 16, mask + 16, comp_pred + 16);
- comp_pred += (width);
+ for (int x = 0; x < width / 32; x++) {
+ comp_mask_pred_16_ssse3(src0 + x * 32, src1 + x * 32, mask + x * 32,
+ comp_pred);
+ comp_mask_pred_16_ssse3(src0 + x * 32 + 16, src1 + x * 32 + 16,
+ mask + x * 32 + 16, comp_pred + 16);
+ comp_pred += 32;
+ }
src0 += (stride0);
src1 += (stride1);
mask += (mask_stride);
diff --git a/aom_dsp/x86/variance_avx2.c b/aom_dsp/x86/variance_avx2.c
index 7510c38..3558070 100644
--- a/aom_dsp/x86/variance_avx2.c
+++ b/aom_dsp/x86/variance_avx2.c
@@ -395,25 +395,22 @@
comp_pred += (16 << 2);
i += 4;
} while (i < height);
- } else { // for width == 32
+ } else {
do {
- const __m256i sA0 = _mm256_lddqu_si256((const __m256i *)(src0));
- const __m256i sA1 = _mm256_lddqu_si256((const __m256i *)(src1));
- const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask));
+ for (int x = 0; x < width / 32; x++) {
+ const __m256i sA0 =
+ _mm256_lddqu_si256((const __m256i *)(src0 + x * 32));
+ const __m256i sA1 =
+ _mm256_lddqu_si256((const __m256i *)(src1 + x * 32));
+ const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask + x * 32));
- const __m256i sB0 = _mm256_lddqu_si256((const __m256i *)(src0 + stride0));
- const __m256i sB1 = _mm256_lddqu_si256((const __m256i *)(src1 + stride1));
- const __m256i aB =
- _mm256_lddqu_si256((const __m256i *)(mask + mask_stride));
-
- comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
- comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
- comp_pred += (32 << 1);
-
- src0 += (stride0 << 1);
- src1 += (stride1 << 1);
- mask += (mask_stride << 1);
- i += 2;
+ comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
+ comp_pred += 32;
+ }
+ src0 += stride0;
+ src1 += stride1;
+ mask += mask_stride;
+ i++;
} while (i < height);
}
}
diff --git a/av1/encoder/mcomp.c b/av1/encoder/mcomp.c
index c636843..704f466 100644
--- a/av1/encoder/mcomp.c
+++ b/av1/encoder/mcomp.c
@@ -2335,7 +2335,7 @@
DECLARE_ALIGNED(16, uint8_t, pred[MAX_SB_SQUARE]);
if (second_pred != NULL) {
if (mask) {
- aom_comp_mask_upsampled_pred_c(
+ aom_comp_mask_upsampled_pred(
xd, cm, mi_row, mi_col, this_mv, pred, second_pred, w, h,
subpel_x_q3, subpel_y_q3, ref, ref_stride, mask, mask_stride,
invert_mask, subpel_search_type);
diff --git a/test/comp_mask_variance_test.cc b/test/comp_mask_variance_test.cc
index fec9248..d07245c 100644
--- a/test/comp_mask_variance_test.cc
+++ b/test/comp_mask_variance_test.cc
@@ -37,8 +37,10 @@
#if HAVE_SSSE3 || HAVE_SSE2 || HAVE_AVX2
const BLOCK_SIZE kValidBlockSize[] = {
- BLOCK_8X8, BLOCK_8X16, BLOCK_8X32, BLOCK_16X8, BLOCK_16X16,
- BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32,
+ BLOCK_8X8, BLOCK_8X16, BLOCK_8X32, BLOCK_16X8, BLOCK_16X16,
+ BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32, BLOCK_32X64,
+ BLOCK_64X32, BLOCK_64X64, BLOCK_64X128, BLOCK_128X64, BLOCK_128X128,
+ BLOCK_16X64, BLOCK_64X16
};
#endif
typedef std::tuple<comp_mask_pred_func, BLOCK_SIZE> CompMaskPredParam;