Add Neon implementation of aom_comp_mask_pred
Add Neon implementation of aom_comp_mask_pred as well as the
corresponding tests.
Change-Id: If9cca05890d5c782b430ae01447a9c773d0d0c95
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 832c8d1..b3597a6 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1965,7 +1965,7 @@
add_proto qw/void aom_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
- specialize qw/aom_comp_mask_pred ssse3 avx2/;
+ specialize qw/aom_comp_mask_pred ssse3 avx2 neon/;
if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
add_proto qw/void aom_highbd_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
diff --git a/aom_dsp/arm/avg_pred_neon.c b/aom_dsp/arm/avg_pred_neon.c
index 9262427..04e0904 100644
--- a/aom_dsp/arm/avg_pred_neon.c
+++ b/aom_dsp/arm/avg_pred_neon.c
@@ -14,6 +14,7 @@
#include "config/aom_dsp_rtcd.h"
#include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/blend.h"
void aom_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
int height, const uint8_t *ref, int ref_stride) {
@@ -72,3 +73,99 @@
} while (--h != 0);
}
}
+
+void aom_comp_mask_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
+ int height, const uint8_t *ref, int ref_stride,
+ const uint8_t *mask, int mask_stride,
+ int invert_mask) {  // Per pixel: comp_pred = (m * s0 + (64 - m) * s1 + 32) >> 6.
+ const uint8_t *src0 = invert_mask ? pred : ref;  // invert_mask swaps which source the mask weights.
+ const uint8_t *src1 = invert_mask ? ref : pred;
+ const int src_stride0 = invert_mask ? width : ref_stride;  // pred is contiguous, so its stride is width.
+ const int src_stride1 = invert_mask ? ref_stride : width;
+
+ if (width > 8) {  // Widths 16 and 32: blend 16 pixels per inner iteration.
+ const uint8x16_t max_alpha = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA);
+ do {
+ const uint8_t *src0_ptr = src0;
+ const uint8_t *src1_ptr = src1;
+ const uint8_t *mask_ptr = mask;
+ uint8_t *comp_pred_ptr = comp_pred;
+ int w = width;
+
+ do {
+ const uint8x16_t s0 = vld1q_u8(src0_ptr);
+ const uint8x16_t s1 = vld1q_u8(src1_ptr);
+ const uint8x16_t m0 = vld1q_u8(mask_ptr);
+
+ uint8x16_t m0_inv = vsubq_u8(max_alpha, m0);  // 64 - m
+ uint16x8_t blend_u16_lo = vmull_u8(vget_low_u8(s0), vget_low_u8(m0));  // m * s0 (widened to u16)
+ uint16x8_t blend_u16_hi = vmull_u8(vget_high_u8(s0), vget_high_u8(m0));
+ blend_u16_lo =
+ vmlal_u8(blend_u16_lo, vget_low_u8(s1), vget_low_u8(m0_inv));  // + (64 - m) * s1
+ blend_u16_hi =
+ vmlal_u8(blend_u16_hi, vget_high_u8(s1), vget_high_u8(m0_inv));
+
+ uint8x8_t blend_u8_lo =
+ vrshrn_n_u16(blend_u16_lo, AOM_BLEND_A64_ROUND_BITS);  // rounding narrow: (x + 32) >> 6
+ uint8x8_t blend_u8_hi =
+ vrshrn_n_u16(blend_u16_hi, AOM_BLEND_A64_ROUND_BITS);
+ uint8x16_t blend_u8 = vcombine_u8(blend_u8_lo, blend_u8_hi);
+
+ vst1q_u8(comp_pred_ptr, blend_u8);
+
+ src0_ptr += 16;
+ src1_ptr += 16;
+ mask_ptr += 16;
+ comp_pred_ptr += 16;
+ w -= 16;
+ } while (w != 0);  // width is a multiple of 16 in this path.
+
+ src0 += src_stride0;
+ src1 += src_stride1;
+ mask += mask_stride;
+ comp_pred += width;  // comp_pred rows are contiguous.
+ } while (--height != 0);
+ } else if (width == 8) {  // One full 8-pixel row per iteration.
+ const uint8x8_t max_alpha = vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA);
+
+ do {
+ const uint8x8_t s0 = vld1_u8(src0);
+ const uint8x8_t s1 = vld1_u8(src1);
+ const uint8x8_t m0 = vld1_u8(mask);
+
+ uint8x8_t m0_inv = vsub_u8(max_alpha, m0);  // 64 - m
+ uint16x8_t blend_u16 = vmull_u8(s0, m0);  // m * s0
+ blend_u16 = vmlal_u8(blend_u16, s1, m0_inv);  // + (64 - m) * s1
+ uint8x8_t blend_u8 = vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS);  // (x + 32) >> 6
+
+ vst1_u8(comp_pred, blend_u8);
+
+ src0 += src_stride0;
+ src1 += src_stride1;
+ mask += mask_stride;
+ comp_pred += 8;
+ } while (--height != 0);
+ } else {  // width == 4: blend two rows (8 pixels) per iteration.
+ const uint8x8_t max_alpha = vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA);
+ int h = height / 2;  // assumes height is even — TODO confirm (true for codec block sizes).
+ assert(width == 4);  // NOTE(review): uses assert — confirm <assert.h> is already included in this file.
+
+ do {
+ const uint8x8_t s0 = load_unaligned_u8(src0, src_stride0);  // two 4-byte rows packed into one vector
+ const uint8x8_t s1 = load_unaligned_u8(src1, src_stride1);
+ const uint8x8_t m0 = load_unaligned_u8(mask, mask_stride);
+
+ uint8x8_t m0_inv = vsub_u8(max_alpha, m0);  // 64 - m
+ uint16x8_t blend_u16 = vmull_u8(s0, m0);  // m * s0
+ blend_u16 = vmlal_u8(blend_u16, s1, m0_inv);  // + (64 - m) * s1
+ uint8x8_t blend_u8 = vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS);  // (x + 32) >> 6
+
+ vst1_u8(comp_pred, blend_u8);  // stores both 4-wide output rows at once
+
+ src0 += 2 * src_stride0;
+ src1 += 2 * src_stride1;
+ mask += 2 * mask_stride;
+ comp_pred += 8;
+ } while (--h != 0);
+ }
+}
diff --git a/test/comp_mask_variance_test.cc b/test/comp_mask_variance_test.cc
index a83a885..51fcf84 100644
--- a/test/comp_mask_variance_test.cc
+++ b/test/comp_mask_variance_test.cc
@@ -39,7 +39,7 @@
int width, int height, const uint8_t *ref,
int ref_stride);
-#if HAVE_SSSE3 || HAVE_SSE2 || HAVE_AVX2
+#if HAVE_SSSE3 || HAVE_SSE2 || HAVE_AVX2 || HAVE_NEON
const BLOCK_SIZE kCompMaskPredParams[] = {
BLOCK_8X8, BLOCK_8X16, BLOCK_8X32, BLOCK_16X8, BLOCK_16X16,
BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32
@@ -179,6 +179,13 @@
::testing::ValuesIn(kCompMaskPredParams)));
#endif
+#if HAVE_NEON
+INSTANTIATE_TEST_SUITE_P(
+ NEON, AV1CompMaskVarianceTest,
+ ::testing::Combine(::testing::Values(&aom_comp_mask_pred_neon),
+ ::testing::ValuesIn(kCompMaskPredParams)));
+#endif
+
#ifndef aom_comp_mask_pred
// can't run this test if aom_comp_mask_pred is defined to aom_comp_mask_pred_c
class AV1CompMaskUpVarianceTest : public AV1CompMaskVarianceTest {