Add Neon implementation of aom_comp_mask_pred

Add Neon implementation of aom_comp_mask_pred as well as the
corresponding tests.

Change-Id: If9cca05890d5c782b430ae01447a9c773d0d0c95
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 832c8d1..b3597a6 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1965,7 +1965,7 @@
 
 
   add_proto qw/void aom_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
-  specialize qw/aom_comp_mask_pred ssse3 avx2/;
+  specialize qw/aom_comp_mask_pred ssse3 avx2 neon/;
 
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
     add_proto qw/void aom_highbd_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
diff --git a/aom_dsp/arm/avg_pred_neon.c b/aom_dsp/arm/avg_pred_neon.c
index 9262427..04e0904 100644
--- a/aom_dsp/arm/avg_pred_neon.c
+++ b/aom_dsp/arm/avg_pred_neon.c
@@ -14,6 +14,7 @@
 
 #include "config/aom_dsp_rtcd.h"
 #include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/blend.h"
 
 void aom_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
                             int height, const uint8_t *ref, int ref_stride) {
@@ -72,3 +73,99 @@
     } while (--h != 0);
   }
 }
+
+void aom_comp_mask_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
+                             int height, const uint8_t *ref, int ref_stride,
+                             const uint8_t *mask, int mask_stride,
+                             int invert_mask) {
+  const uint8_t *src0 = invert_mask ? pred : ref;
+  const uint8_t *src1 = invert_mask ? ref : pred;
+  const int src_stride0 = invert_mask ? width : ref_stride;
+  const int src_stride1 = invert_mask ? ref_stride : width;
+
+  if (width > 8) {
+    const uint8x16_t max_alpha = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA);
+    do {
+      const uint8_t *src0_ptr = src0;
+      const uint8_t *src1_ptr = src1;
+      const uint8_t *mask_ptr = mask;
+      uint8_t *comp_pred_ptr = comp_pred;
+      int w = width;
+
+      do {
+        const uint8x16_t s0 = vld1q_u8(src0_ptr);
+        const uint8x16_t s1 = vld1q_u8(src1_ptr);
+        const uint8x16_t m0 = vld1q_u8(mask_ptr);
+
+        uint8x16_t m0_inv = vsubq_u8(max_alpha, m0);
+        uint16x8_t blend_u16_lo = vmull_u8(vget_low_u8(s0), vget_low_u8(m0));
+        uint16x8_t blend_u16_hi = vmull_u8(vget_high_u8(s0), vget_high_u8(m0));
+        blend_u16_lo =
+            vmlal_u8(blend_u16_lo, vget_low_u8(s1), vget_low_u8(m0_inv));
+        blend_u16_hi =
+            vmlal_u8(blend_u16_hi, vget_high_u8(s1), vget_high_u8(m0_inv));
+
+        uint8x8_t blend_u8_lo =
+            vrshrn_n_u16(blend_u16_lo, AOM_BLEND_A64_ROUND_BITS);
+        uint8x8_t blend_u8_hi =
+            vrshrn_n_u16(blend_u16_hi, AOM_BLEND_A64_ROUND_BITS);
+        uint8x16_t blend_u8 = vcombine_u8(blend_u8_lo, blend_u8_hi);
+
+        vst1q_u8(comp_pred_ptr, blend_u8);
+
+        src0_ptr += 16;
+        src1_ptr += 16;
+        mask_ptr += 16;
+        comp_pred_ptr += 16;
+        w -= 16;
+      } while (w != 0);
+
+      src0 += src_stride0;
+      src1 += src_stride1;
+      mask += mask_stride;
+      comp_pred += width;
+    } while (--height != 0);
+  } else if (width == 8) {
+    const uint8x8_t max_alpha = vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA);
+
+    do {
+      const uint8x8_t s0 = vld1_u8(src0);
+      const uint8x8_t s1 = vld1_u8(src1);
+      const uint8x8_t m0 = vld1_u8(mask);
+
+      uint8x8_t m0_inv = vsub_u8(max_alpha, m0);
+      uint16x8_t blend_u16 = vmull_u8(s0, m0);
+      blend_u16 = vmlal_u8(blend_u16, s1, m0_inv);
+      uint8x8_t blend_u8 = vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS);
+
+      vst1_u8(comp_pred, blend_u8);
+
+      src0 += src_stride0;
+      src1 += src_stride1;
+      mask += mask_stride;
+      comp_pred += 8;
+    } while (--height != 0);
+  } else {
+    const uint8x8_t max_alpha = vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA);
+    int h = height / 2;
+    assert(width == 4);
+
+    do {
+      const uint8x8_t s0 = load_unaligned_u8(src0, src_stride0);
+      const uint8x8_t s1 = load_unaligned_u8(src1, src_stride1);
+      const uint8x8_t m0 = load_unaligned_u8(mask, mask_stride);
+
+      uint8x8_t m0_inv = vsub_u8(max_alpha, m0);
+      uint16x8_t blend_u16 = vmull_u8(s0, m0);
+      blend_u16 = vmlal_u8(blend_u16, s1, m0_inv);
+      uint8x8_t blend_u8 = vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS);
+
+      vst1_u8(comp_pred, blend_u8);
+
+      src0 += 2 * src_stride0;
+      src1 += 2 * src_stride1;
+      mask += 2 * mask_stride;
+      comp_pred += 8;
+    } while (--h != 0);
+  }
+}
diff --git a/test/comp_mask_variance_test.cc b/test/comp_mask_variance_test.cc
index a83a885..51fcf84 100644
--- a/test/comp_mask_variance_test.cc
+++ b/test/comp_mask_variance_test.cc
@@ -39,7 +39,7 @@
                                    int width, int height, const uint8_t *ref,
                                    int ref_stride);
 
-#if HAVE_SSSE3 || HAVE_SSE2 || HAVE_AVX2
+#if HAVE_SSSE3 || HAVE_SSE2 || HAVE_AVX2 || HAVE_NEON
 const BLOCK_SIZE kCompMaskPredParams[] = {
   BLOCK_8X8,   BLOCK_8X16, BLOCK_8X32,  BLOCK_16X8, BLOCK_16X16,
   BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32
@@ -179,6 +179,13 @@
                        ::testing::ValuesIn(kCompMaskPredParams)));
 #endif
 
+#if HAVE_NEON
+INSTANTIATE_TEST_SUITE_P(
+    NEON, AV1CompMaskVarianceTest,
+    ::testing::Combine(::testing::Values(&aom_comp_mask_pred_neon),
+                       ::testing::ValuesIn(kCompMaskPredParams)));
+#endif
+
 #ifndef aom_comp_mask_pred
 // can't run this test if aom_comp_mask_pred is defined to aom_comp_mask_pred_c
 class AV1CompMaskUpVarianceTest : public AV1CompMaskVarianceTest {