Support widths > 32 in aom_comp_mask_pred()
for AVX2 and SSSE3.
Update the unit tests to cover widths > 32.
BUG=aomedia:2848
Change-Id: Ib5829dae3818b28a6ab2644887fa18bae78d9d44
(cherry picked from commit c05e2fc986da30639c56c3f6bb83adfd38a1a77e)
diff --git a/aom_dsp/x86/masked_variance_intrin_ssse3.c b/aom_dsp/x86/masked_variance_intrin_ssse3.c
index ebf4631..8811829 100644
--- a/aom_dsp/x86/masked_variance_intrin_ssse3.c
+++ b/aom_dsp/x86/masked_variance_intrin_ssse3.c
@@ -1050,12 +1050,15 @@
mask += (mask_stride << 1);
i += 2;
} while (i < height);
- } else { // width == 32
- assert(width == 32);
+ } else {
do {
- comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
- comp_mask_pred_16_ssse3(src0 + 16, src1 + 16, mask + 16, comp_pred + 16);
- comp_pred += (width);
+ for (int x = 0; x < width / 32; x++) {
+ comp_mask_pred_16_ssse3(src0 + x * 32, src1 + x * 32, mask + x * 32,
+ comp_pred);
+ comp_mask_pred_16_ssse3(src0 + x * 32 + 16, src1 + x * 32 + 16,
+ mask + x * 32 + 16, comp_pred + 16);
+ comp_pred += 32;
+ }
src0 += (stride0);
src1 += (stride1);
mask += (mask_stride);
diff --git a/aom_dsp/x86/variance_avx2.c b/aom_dsp/x86/variance_avx2.c
index 7510c38..3558070 100644
--- a/aom_dsp/x86/variance_avx2.c
+++ b/aom_dsp/x86/variance_avx2.c
@@ -395,25 +395,22 @@
comp_pred += (16 << 2);
i += 4;
} while (i < height);
- } else { // for width == 32
+ } else {
do {
- const __m256i sA0 = _mm256_lddqu_si256((const __m256i *)(src0));
- const __m256i sA1 = _mm256_lddqu_si256((const __m256i *)(src1));
- const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask));
+ for (int x = 0; x < width / 32; x++) {
+ const __m256i sA0 =
+ _mm256_lddqu_si256((const __m256i *)(src0 + x * 32));
+ const __m256i sA1 =
+ _mm256_lddqu_si256((const __m256i *)(src1 + x * 32));
+ const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask + x * 32));
- const __m256i sB0 = _mm256_lddqu_si256((const __m256i *)(src0 + stride0));
- const __m256i sB1 = _mm256_lddqu_si256((const __m256i *)(src1 + stride1));
- const __m256i aB =
- _mm256_lddqu_si256((const __m256i *)(mask + mask_stride));
-
- comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
- comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
- comp_pred += (32 << 1);
-
- src0 += (stride0 << 1);
- src1 += (stride1 << 1);
- mask += (mask_stride << 1);
- i += 2;
+ comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
+ comp_pred += 32;
+ }
+ src0 += stride0;
+ src1 += stride1;
+ mask += mask_stride;
+ i++;
} while (i < height);
}
}
diff --git a/test/comp_mask_variance_test.cc b/test/comp_mask_variance_test.cc
index 05d69f4..2e9ca17 100644
--- a/test/comp_mask_variance_test.cc
+++ b/test/comp_mask_variance_test.cc
@@ -37,8 +37,10 @@
#if HAVE_SSSE3 || HAVE_SSE2 || HAVE_AVX2
const BLOCK_SIZE kValidBlockSize[] = {
- BLOCK_8X8, BLOCK_8X16, BLOCK_8X32, BLOCK_16X8, BLOCK_16X16,
- BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32,
+ BLOCK_8X8, BLOCK_8X16, BLOCK_8X32, BLOCK_16X8, BLOCK_16X16,
+ BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32, BLOCK_32X64,
+ BLOCK_64X32, BLOCK_64X64, BLOCK_64X128, BLOCK_128X64, BLOCK_128X128,
+ BLOCK_16X64, BLOCK_64X16
};
#endif
typedef std::tuple<comp_mask_pred_func, BLOCK_SIZE> CompMaskPredParam;