Add aom_comp_mask_pred_avx2
1. Add AVX2 implementation of aom_comp_mask_pred.
2. For width 8 still use ssse3 version.
3. For other widths (16, 32), the AVX2 version is 1.2x-2.0x faster
than the ssse3 version.
Change-Id: I80acc1be54ab21a52f7847e91b1299853add757c
diff --git a/aom_dsp/aom_dsp.cmake b/aom_dsp/aom_dsp.cmake
index 699f8b4..f61af74 100644
--- a/aom_dsp/aom_dsp.cmake
+++ b/aom_dsp/aom_dsp.cmake
@@ -392,6 +392,7 @@
set(AOM_DSP_ENCODER_INTRIN_SSSE3
${AOM_DSP_ENCODER_INTRIN_SSSE3}
"${AOM_ROOT}/aom_dsp/x86/masked_sad_intrin_ssse3.c"
+ "${AOM_ROOT}/aom_dsp/x86/masked_variance_intrin_ssse3.h"
"${AOM_ROOT}/aom_dsp/x86/masked_variance_intrin_ssse3.c")
set(AOM_DSP_ENCODER_INTRIN_SSE2
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index cde4d1d..326fd6c 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1784,7 +1784,7 @@
add_proto qw/void aom_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
- specialize qw/aom_comp_mask_pred ssse3/;
+ specialize qw/aom_comp_mask_pred ssse3 avx2/;
add_proto qw/void aom_highbd_comp_mask_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
add_proto qw/void aom_highbd_comp_mask_upsampled_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, int subsample_x_q3, int subsample_y_q3, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask, int bd";
diff --git a/aom_dsp/x86/masked_variance_intrin_ssse3.c b/aom_dsp/x86/masked_variance_intrin_ssse3.c
index 468766c..9ae1de6 100644
--- a/aom_dsp/x86/masked_variance_intrin_ssse3.c
+++ b/aom_dsp/x86/masked_variance_intrin_ssse3.c
@@ -15,11 +15,12 @@
#include "./aom_config.h"
#include "./aom_dsp_rtcd.h"
-#include "aom_dsp/blend.h"
#include "aom/aom_integer.h"
-#include "aom_ports/mem.h"
#include "aom_dsp/aom_filter.h"
+#include "aom_dsp/blend.h"
+#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
#include "aom_dsp/x86/synonyms.h"
+#include "aom_ports/mem.h"
// For width a multiple of 16
static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
@@ -1040,32 +1041,6 @@
*sse = _mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}
-static INLINE void comp_mask_pred_16_ssse3(const uint8_t *src0,
- const uint8_t *src1,
- const uint8_t *mask, uint8_t *dst) {
- const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
- const __m128i round_offset =
- _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
-
- const __m128i sA0 = _mm_lddqu_si128((const __m128i *)(src0));
- const __m128i sA1 = _mm_lddqu_si128((const __m128i *)(src1));
- const __m128i aA = _mm_load_si128((const __m128i *)(mask));
-
- const __m128i maA = _mm_sub_epi8(alpha_max, aA);
-
- const __m128i ssAL = _mm_unpacklo_epi8(sA0, sA1);
- const __m128i aaAL = _mm_unpacklo_epi8(aA, maA);
- const __m128i ssAH = _mm_unpackhi_epi8(sA0, sA1);
- const __m128i aaAH = _mm_unpackhi_epi8(aA, maA);
-
- const __m128i blendAL = _mm_maddubs_epi16(ssAL, aaAL);
- const __m128i blendAH = _mm_maddubs_epi16(ssAH, aaAH);
-
- const __m128i roundAL = _mm_mulhrs_epi16(blendAL, round_offset);
- const __m128i roundAH = _mm_mulhrs_epi16(blendAH, round_offset);
- _mm_store_si128((__m128i *)dst, _mm_packus_epi16(roundAL, roundAH));
-}
-
void aom_comp_mask_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
int width, int height, const uint8_t *ref,
int ref_stride, const uint8_t *mask,
@@ -1074,46 +1049,11 @@
const uint8_t *src1 = invert_mask ? ref : pred;
const int stride0 = invert_mask ? width : ref_stride;
const int stride1 = invert_mask ? ref_stride : width;
-
- const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
- const __m128i round_offset =
- _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
assert(height % 2 == 0);
- assert(width % 8 == 0);
int i = 0;
if (width == 8) {
- do {
- // odd line A
- const __m128i sA0 = _mm_loadl_epi64((const __m128i *)(src0));
- const __m128i sA1 = _mm_loadl_epi64((const __m128i *)(src1));
- const __m128i aA = _mm_loadl_epi64((const __m128i *)(mask));
-
- // even line B
- const __m128i sB0 = _mm_loadl_epi64((const __m128i *)(src0 + stride0));
- const __m128i sB1 = _mm_loadl_epi64((const __m128i *)(src1 + stride1));
- const __m128i a = _mm_castps_si128(_mm_loadh_pi(
- _mm_castsi128_ps(aA), (const __m64 *)(mask + mask_stride)));
-
- const __m128i ssA = _mm_unpacklo_epi8(sA0, sA1);
- const __m128i ssB = _mm_unpacklo_epi8(sB0, sB1);
-
- const __m128i ma = _mm_sub_epi8(alpha_max, a);
- const __m128i aaA = _mm_unpacklo_epi8(a, ma);
- const __m128i aaB = _mm_unpackhi_epi8(a, ma);
-
- const __m128i blendA = _mm_maddubs_epi16(ssA, aaA);
- const __m128i blendB = _mm_maddubs_epi16(ssB, aaB);
- const __m128i roundA = _mm_mulhrs_epi16(blendA, round_offset);
- const __m128i roundB = _mm_mulhrs_epi16(blendB, round_offset);
- const __m128i round = _mm_packus_epi16(roundA, roundB);
- // comp_pred's stride == width == 8
- _mm_store_si128((__m128i *)(comp_pred), round);
- comp_pred += (width << 1);
- src0 += (stride0 << 1);
- src1 += (stride1 << 1);
- mask += (mask_stride << 1);
- i += 2;
- } while (i < height);
+ comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
+ mask, mask_stride);
} else if (width == 16) {
do {
comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
@@ -1126,6 +1066,7 @@
i += 2;
} while (i < height);
} else { // width == 32
+ assert(width == 32);
do {
comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
comp_mask_pred_16_ssse3(src0 + 16, src1 + 16, mask + 16, comp_pred + 16);
diff --git a/aom_dsp/x86/masked_variance_intrin_ssse3.h b/aom_dsp/x86/masked_variance_intrin_ssse3.h
new file mode 100644
index 0000000..4d058a1
--- /dev/null
+++ b/aom_dsp/x86/masked_variance_intrin_ssse3.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_DSP_X86_MASKED_VARIANCE_INTRIN_SSSE3_H_
+#define AOM_DSP_X86_MASKED_VARIANCE_INTRIN_SSSE3_H_
+
+#include <stdlib.h>
+#include <string.h>
+#include <tmmintrin.h>
+
+#include "./aom_config.h"
+#include "./aom_dsp_rtcd.h"
+#include "aom_dsp/blend.h"
+
+static INLINE void comp_mask_pred_16_ssse3(const uint8_t *src0,  // blends 16 pixels: dst = (src0*mask + src1*(64-mask)) rounded >> 6
+                                           const uint8_t *src1,
+                                           const uint8_t *mask, uint8_t *dst) {
+  const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);  // 64
+  const __m128i round_offset =
+      _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));  // mulhrs constant for a rounding shift by AOM_BLEND_A64_ROUND_BITS
+
+  const __m128i sA0 = _mm_lddqu_si128((const __m128i *)(src0));
+  const __m128i sA1 = _mm_lddqu_si128((const __m128i *)(src1));
+  const __m128i aA = _mm_load_si128((const __m128i *)(mask));  // NOTE(review): aligned load -- assumes mask is 16-byte aligned; confirm callers
+
+  const __m128i maA = _mm_sub_epi8(alpha_max, aA);  // complementary weight: 64 - mask
+
+  const __m128i ssAL = _mm_unpacklo_epi8(sA0, sA1);  // interleave so maddubs pairs src0[i] with src1[i]
+  const __m128i aaAL = _mm_unpacklo_epi8(aA, maA);
+  const __m128i ssAH = _mm_unpackhi_epi8(sA0, sA1);
+  const __m128i aaAH = _mm_unpackhi_epi8(aA, maA);
+
+  const __m128i blendAL = _mm_maddubs_epi16(ssAL, aaAL);  // 16-bit sums: src0*a + src1*(64-a)
+  const __m128i blendAH = _mm_maddubs_epi16(ssAH, aaAH);
+
+  const __m128i roundAL = _mm_mulhrs_epi16(blendAL, round_offset);  // rounding shift back to 8-bit range
+  const __m128i roundAH = _mm_mulhrs_epi16(blendAH, round_offset);
+  _mm_store_si128((__m128i *)dst, _mm_packus_epi16(roundAL, roundAH));  // NOTE(review): aligned store -- assumes dst is 16-byte aligned
+}
+
+static INLINE void comp_mask_pred_8_ssse3(uint8_t *comp_pred, int height,  // width-8 blend: two rows per iteration; assumes height % 2 == 0
+                                          const uint8_t *src0, int stride0,
+                                          const uint8_t *src1, int stride1,
+                                          const uint8_t *mask,
+                                          int mask_stride) {
+  int i = 0;
+  const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);  // 64
+  const __m128i round_offset =
+      _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));  // mulhrs constant for a rounding shift by AOM_BLEND_A64_ROUND_BITS
+  do {
+    // odd line A
+    const __m128i sA0 = _mm_loadl_epi64((const __m128i *)(src0));
+    const __m128i sA1 = _mm_loadl_epi64((const __m128i *)(src1));
+    const __m128i aA = _mm_loadl_epi64((const __m128i *)(mask));
+    // even line B
+    const __m128i sB0 = _mm_loadl_epi64((const __m128i *)(src0 + stride0));
+    const __m128i sB1 = _mm_loadl_epi64((const __m128i *)(src1 + stride1));
+    const __m128i a = _mm_castps_si128(_mm_loadh_pi(
+        _mm_castsi128_ps(aA), (const __m64 *)(mask + mask_stride)));  // mask row A in low 8 bytes, row B in high 8 bytes
+
+    const __m128i ssA = _mm_unpacklo_epi8(sA0, sA1);  // interleave so maddubs pairs src0[i] with src1[i]
+    const __m128i ssB = _mm_unpacklo_epi8(sB0, sB1);
+
+    const __m128i ma = _mm_sub_epi8(alpha_max, a);  // complementary weight: 64 - mask
+    const __m128i aaA = _mm_unpacklo_epi8(a, ma);
+    const __m128i aaB = _mm_unpackhi_epi8(a, ma);
+
+    const __m128i blendA = _mm_maddubs_epi16(ssA, aaA);  // 16-bit sums: src0*a + src1*(64-a)
+    const __m128i blendB = _mm_maddubs_epi16(ssB, aaB);
+    const __m128i roundA = _mm_mulhrs_epi16(blendA, round_offset);  // rounding shift back to 8-bit range
+    const __m128i roundB = _mm_mulhrs_epi16(blendB, round_offset);
+    const __m128i round = _mm_packus_epi16(roundA, roundB);  // rows A and B packed into one 16-byte store
+    // comp_pred's stride == width == 8
+    _mm_store_si128((__m128i *)(comp_pred), round);
+    comp_pred += (8 << 1);  // advance two rows of the stride-8 output
+    src0 += (stride0 << 1);
+    src1 += (stride1 << 1);
+    mask += (mask_stride << 1);
+    i += 2;
+  } while (i < height);
+}
+
+#endif
diff --git a/aom_dsp/x86/variance_avx2.c b/aom_dsp/x86/variance_avx2.c
index 18a70df..a041bba 100644
--- a/aom_dsp/x86/variance_avx2.c
+++ b/aom_dsp/x86/variance_avx2.c
@@ -11,6 +11,7 @@
#include <immintrin.h>
#include "./aom_dsp_rtcd.h"
+#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
typedef void (*get_var_avx2)(const uint8_t *src, int src_stride,
const uint8_t *ref, int ref_stride,
@@ -190,3 +191,87 @@
_mm256_zeroupper();
return variance;
}
+
+static INLINE __m256i mm256_loadu2(const uint8_t *p0, const uint8_t *p1) {  // unaligned load: p1 -> low 128-bit lane, p0 -> high lane
+  const __m256i d =
+      _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)p1));
+  return _mm256_inserti128_si256(d, _mm_loadu_si128((const __m128i *)p0), 1);  // integer-domain insert (AVX2), not insertf128
+}
+
+static INLINE void comp_mask_pred_line_avx2(const __m256i s0, const __m256i s1,  // blends 32 pixels: s0*a + s1*(64-a), rounded >> 6
+                                            const __m256i a,
+                                            uint8_t *comp_pred) {
+  const __m256i alpha_max = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);  // 64
+  const int16_t round_bits = 15 - AOM_BLEND_A64_ROUND_BITS;
+  const __m256i round_offset = _mm256_set1_epi16(1 << (round_bits));  // mulhrs constant for a rounding shift by AOM_BLEND_A64_ROUND_BITS
+
+  const __m256i ma = _mm256_sub_epi8(alpha_max, a);  // complementary weight: 64 - a
+
+  const __m256i ssAL = _mm256_unpacklo_epi8(s0, s1);  // interleave so maddubs pairs s0[i] with s1[i]
+  const __m256i aaAL = _mm256_unpacklo_epi8(a, ma);
+  const __m256i ssAH = _mm256_unpackhi_epi8(s0, s1);
+  const __m256i aaAH = _mm256_unpackhi_epi8(a, ma);
+
+  const __m256i blendAL = _mm256_maddubs_epi16(ssAL, aaAL);  // 16-bit sums: s0*a + s1*(64-a)
+  const __m256i blendAH = _mm256_maddubs_epi16(ssAH, aaAH);
+  const __m256i roundAL = _mm256_mulhrs_epi16(blendAL, round_offset);  // rounding shift back to 8-bit range
+  const __m256i roundAH = _mm256_mulhrs_epi16(blendAH, round_offset);
+
+  const __m256i roundA = _mm256_packus_epi16(roundAL, roundAH);  // per-lane pack preserves row order set up by unpacklo/hi
+  _mm256_storeu_si256((__m256i *)(comp_pred), roundA);  // 32 output pixels
+}
+
+void aom_comp_mask_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,  // comp_pred = blend(ref, pred, mask); invert_mask swaps the operand roles
+                             int height, const uint8_t *ref, int ref_stride,
+                             const uint8_t *mask, int mask_stride,
+                             int invert_mask) {
+  int i = 0;
+  const uint8_t *src0 = invert_mask ? pred : ref;
+  const uint8_t *src1 = invert_mask ? ref : pred;
+  const int stride0 = invert_mask ? width : ref_stride;  // pred is packed at stride == width
+  const int stride1 = invert_mask ? ref_stride : width;
+  if (width == 8) {  // AVX2 gains nothing at width 8; reuse the SSSE3 kernel
+    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
+                           mask, mask_stride);
+  } else if (width == 16) {  // 4 rows per iteration -- assumes height % 4 == 0; TODO(review): confirm no 16x4 callers
+    do {
+      const __m256i sA0 = mm256_loadu2(src0 + stride0, src0);  // rows i (low lane) and i+1 (high lane)
+      const __m256i sA1 = mm256_loadu2(src1 + stride1, src1);
+      const __m256i aA = mm256_loadu2(mask + mask_stride, mask);
+      src0 += (stride0 << 1);
+      src1 += (stride1 << 1);
+      mask += (mask_stride << 1);
+      const __m256i sB0 = mm256_loadu2(src0 + stride0, src0);  // rows i+2 and i+3
+      const __m256i sB1 = mm256_loadu2(src1 + stride1, src1);
+      const __m256i aB = mm256_loadu2(mask + mask_stride, mask);
+      src0 += (stride0 << 1);
+      src1 += (stride1 << 1);
+      mask += (mask_stride << 1);
+      // comp_pred's stride == width == 16
+      comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
+      comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
+      comp_pred += (16 << 2);  // advance four rows of the stride-16 output
+      i += 4;
+    } while (i < height);
+  } else { // width == 32: one full 32-byte row per blend call, two rows per iteration
+    do {
+      const __m256i sA0 = _mm256_lddqu_si256((const __m256i *)(src0));
+      const __m256i sA1 = _mm256_lddqu_si256((const __m256i *)(src1));
+      const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask));
+
+      const __m256i sB0 = _mm256_lddqu_si256((const __m256i *)(src0 + stride0));
+      const __m256i sB1 = _mm256_lddqu_si256((const __m256i *)(src1 + stride1));
+      const __m256i aB =
+          _mm256_lddqu_si256((const __m256i *)(mask + mask_stride));
+
+      comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
+      comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
+      comp_pred += (32 << 1);  // advance two rows of the stride-32 output
+
+      src0 += (stride0 << 1);
+      src1 += (stride1 << 1);
+      mask += (mask_stride << 1);
+      i += 2;
+    } while (i < height);
+  }
+}
diff --git a/test/comp_mask_variance_test.cc b/test/comp_mask_variance_test.cc
index acb0e8c..b6fc75b 100644
--- a/test/comp_mask_variance_test.cc
+++ b/test/comp_mask_variance_test.cc
@@ -11,7 +11,6 @@
#include <cstdlib>
#include <new>
-#include <vector>
#include "./aom_config.h"
#include "./aom_dsp_rtcd.h"
@@ -28,15 +27,13 @@
#include "test/util.h"
#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
-using std::vector;
-
namespace AV1CompMaskVariance {
typedef void (*comp_mask_pred_func)(uint8_t *comp_pred, const uint8_t *pred,
int width, int height, const uint8_t *ref,
int ref_stride, const uint8_t *mask,
int mask_stride, int invert_mask);
-const BLOCK_SIZE valid_bsize[] = {
+const BLOCK_SIZE kValidBlockSize[] = {
BLOCK_8X8, BLOCK_8X16, BLOCK_8X32, BLOCK_16X8, BLOCK_16X16,
BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32,
};
@@ -53,12 +50,13 @@
protected:
void RunCheckOutput(comp_mask_pred_func test_impl, BLOCK_SIZE bsize, int inv);
void RunSpeedTest(comp_mask_pred_func test_impl, BLOCK_SIZE bsize);
- bool CheckResult(int w, int h) {
- for (int i = 0; i < h; ++i) {
- for (int j = 0; j < w; ++j) {
- int idx = i * w + j;
+ bool CheckResult(int width, int height) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ const int idx = y * width + x;
if (comp_pred1_[idx] != comp_pred2_[idx]) {
- printf("%dx%d mismatch @%d(%d,%d) ", w, h, idx, i, j);
+ printf("%dx%d mismatch @%d(%d,%d) ", width, height, idx, y, x);
+ printf("%d != %d ", comp_pred1_[idx], comp_pred2_[idx]);
return false;
}
}
@@ -160,7 +158,14 @@
INSTANTIATE_TEST_CASE_P(
SSSE3, AV1CompMaskVarianceTest,
::testing::Combine(::testing::Values(&aom_comp_mask_pred_ssse3),
- ::testing::ValuesIn(valid_bsize)));
+ ::testing::ValuesIn(kValidBlockSize)));
+#endif
+
+#if HAVE_AVX2
+INSTANTIATE_TEST_CASE_P(
+ AVX2, AV1CompMaskVarianceTest,
+ ::testing::Combine(::testing::Values(&aom_comp_mask_pred_avx2),
+ ::testing::ValuesIn(kValidBlockSize)));
#endif
#ifndef aom_comp_mask_pred
@@ -249,7 +254,15 @@
INSTANTIATE_TEST_CASE_P(
SSSE3, AV1CompMaskUpVarianceTest,
::testing::Combine(::testing::Values(&aom_comp_mask_pred_ssse3),
- ::testing::ValuesIn(valid_bsize)));
+ ::testing::ValuesIn(kValidBlockSize)));
#endif
+
+#if HAVE_AVX2
+INSTANTIATE_TEST_CASE_P(
+ AVX2, AV1CompMaskUpVarianceTest,
+ ::testing::Combine(::testing::Values(&aom_comp_mask_pred_avx2),
+ ::testing::ValuesIn(kValidBlockSize)));
#endif
+
+#endif // ifndef aom_comp_mask_pred
} // namespace AV1CompMaskVariance