Add aom_comp_mask_pred_avx2

1. Add AVX2 implementation of aom_comp_mask_pred.
2. For width 8, still use the SSSE3 version.
3. For other widths (16, 32), the AVX2 version is 1.2x-2.0x
   faster than the SSSE3 version.
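
For reference, the blend both kernels compute is the AOM_BLEND_A64
average: with AOM_BLEND_A64_MAX_ALPHA == 64 and AOM_BLEND_A64_ROUND_BITS
== 6, each output pixel is (m * ref + (64 - m) * pred + 32) >> 6. A
minimal scalar sketch of the same operation (illustration only, not part
of this change; comp_mask_pred_ref is a hypothetical name):

  #include <stdint.h>

  static void comp_mask_pred_ref(uint8_t *comp_pred, const uint8_t *pred,
                                 int width, int height, const uint8_t *ref,
                                 int ref_stride, const uint8_t *mask,
                                 int mask_stride, int invert_mask) {
    // pred and comp_pred are packed: their stride equals width.
    for (int y = 0; y < height; ++y) {
      for (int x = 0; x < width; ++x) {
        // invert_mask swaps which predictor receives the mask weight.
        const int m = invert_mask ? 64 - mask[x] : mask[x];
        comp_pred[x] = (uint8_t)((m * ref[x] + (64 - m) * pred[x] + 32) >> 6);
      }
      comp_pred += width;
      pred += width;
      ref += ref_stride;
      mask += mask_stride;
    }
  }

The SIMD versions reach the same result by interleaving bytes so that
maddubs produces m * src0 + (64 - m) * src1 in 16-bit lanes, then using
mulhrs with 1 << (15 - 6): mulhrs computes (x * b + (1 << 14)) >> 15,
which for b == 512 is exactly (x + 32) >> 6.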

Change-Id: I80acc1be54ab21a52f7847e91b1299853add757c
diff --git a/aom_dsp/aom_dsp.cmake b/aom_dsp/aom_dsp.cmake
index 699f8b4..f61af74 100644
--- a/aom_dsp/aom_dsp.cmake
+++ b/aom_dsp/aom_dsp.cmake
@@ -392,6 +392,7 @@
       set(AOM_DSP_ENCODER_INTRIN_SSSE3
           ${AOM_DSP_ENCODER_INTRIN_SSSE3}
           "${AOM_ROOT}/aom_dsp/x86/masked_sad_intrin_ssse3.c"
+          "${AOM_ROOT}/aom_dsp/x86/masked_variance_intrin_ssse3.h"
           "${AOM_ROOT}/aom_dsp/x86/masked_variance_intrin_ssse3.c")
 
       set(AOM_DSP_ENCODER_INTRIN_SSE2
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index cde4d1d..326fd6c 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1784,7 +1784,7 @@
 
 
   add_proto qw/void aom_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
-  specialize qw/aom_comp_mask_pred ssse3/;
+  specialize qw/aom_comp_mask_pred ssse3 avx2/;
 
   add_proto qw/void aom_highbd_comp_mask_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
   add_proto qw/void aom_highbd_comp_mask_upsampled_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, int subsample_x_q3, int subsample_y_q3, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask, int bd";
diff --git a/aom_dsp/x86/masked_variance_intrin_ssse3.c b/aom_dsp/x86/masked_variance_intrin_ssse3.c
index 468766c..9ae1de6 100644
--- a/aom_dsp/x86/masked_variance_intrin_ssse3.c
+++ b/aom_dsp/x86/masked_variance_intrin_ssse3.c
@@ -15,11 +15,12 @@
 
 #include "./aom_config.h"
 #include "./aom_dsp_rtcd.h"
-#include "aom_dsp/blend.h"
 #include "aom/aom_integer.h"
-#include "aom_ports/mem.h"
 #include "aom_dsp/aom_filter.h"
+#include "aom_dsp/blend.h"
+#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
 #include "aom_dsp/x86/synonyms.h"
+#include "aom_ports/mem.h"
 
 // For width a multiple of 16
 static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
@@ -1040,32 +1041,6 @@
   *sse = _mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
 }
 
-static INLINE void comp_mask_pred_16_ssse3(const uint8_t *src0,
-                                           const uint8_t *src1,
-                                           const uint8_t *mask, uint8_t *dst) {
-  const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
-  const __m128i round_offset =
-      _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
-
-  const __m128i sA0 = _mm_lddqu_si128((const __m128i *)(src0));
-  const __m128i sA1 = _mm_lddqu_si128((const __m128i *)(src1));
-  const __m128i aA = _mm_load_si128((const __m128i *)(mask));
-
-  const __m128i maA = _mm_sub_epi8(alpha_max, aA);
-
-  const __m128i ssAL = _mm_unpacklo_epi8(sA0, sA1);
-  const __m128i aaAL = _mm_unpacklo_epi8(aA, maA);
-  const __m128i ssAH = _mm_unpackhi_epi8(sA0, sA1);
-  const __m128i aaAH = _mm_unpackhi_epi8(aA, maA);
-
-  const __m128i blendAL = _mm_maddubs_epi16(ssAL, aaAL);
-  const __m128i blendAH = _mm_maddubs_epi16(ssAH, aaAH);
-
-  const __m128i roundAL = _mm_mulhrs_epi16(blendAL, round_offset);
-  const __m128i roundAH = _mm_mulhrs_epi16(blendAH, round_offset);
-  _mm_store_si128((__m128i *)dst, _mm_packus_epi16(roundAL, roundAH));
-}
-
 void aom_comp_mask_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
                               int width, int height, const uint8_t *ref,
                               int ref_stride, const uint8_t *mask,
@@ -1074,46 +1049,11 @@
   const uint8_t *src1 = invert_mask ? ref : pred;
   const int stride0 = invert_mask ? width : ref_stride;
   const int stride1 = invert_mask ? ref_stride : width;
-
-  const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
-  const __m128i round_offset =
-      _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
   assert(height % 2 == 0);
-  assert(width % 8 == 0);
   int i = 0;
   if (width == 8) {
-    do {
-      // odd line A
-      const __m128i sA0 = _mm_loadl_epi64((const __m128i *)(src0));
-      const __m128i sA1 = _mm_loadl_epi64((const __m128i *)(src1));
-      const __m128i aA = _mm_loadl_epi64((const __m128i *)(mask));
-
-      // even line B
-      const __m128i sB0 = _mm_loadl_epi64((const __m128i *)(src0 + stride0));
-      const __m128i sB1 = _mm_loadl_epi64((const __m128i *)(src1 + stride1));
-      const __m128i a = _mm_castps_si128(_mm_loadh_pi(
-          _mm_castsi128_ps(aA), (const __m64 *)(mask + mask_stride)));
-
-      const __m128i ssA = _mm_unpacklo_epi8(sA0, sA1);
-      const __m128i ssB = _mm_unpacklo_epi8(sB0, sB1);
-
-      const __m128i ma = _mm_sub_epi8(alpha_max, a);
-      const __m128i aaA = _mm_unpacklo_epi8(a, ma);
-      const __m128i aaB = _mm_unpackhi_epi8(a, ma);
-
-      const __m128i blendA = _mm_maddubs_epi16(ssA, aaA);
-      const __m128i blendB = _mm_maddubs_epi16(ssB, aaB);
-      const __m128i roundA = _mm_mulhrs_epi16(blendA, round_offset);
-      const __m128i roundB = _mm_mulhrs_epi16(blendB, round_offset);
-      const __m128i round = _mm_packus_epi16(roundA, roundB);
-      // comp_pred's stride == width == 8
-      _mm_store_si128((__m128i *)(comp_pred), round);
-      comp_pred += (width << 1);
-      src0 += (stride0 << 1);
-      src1 += (stride1 << 1);
-      mask += (mask_stride << 1);
-      i += 2;
-    } while (i < height);
+    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
+                           mask, mask_stride);
   } else if (width == 16) {
     do {
       comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
@@ -1126,6 +1066,7 @@
       i += 2;
     } while (i < height);
   } else {  // width == 32
+    assert(width == 32);
     do {
       comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
       comp_mask_pred_16_ssse3(src0 + 16, src1 + 16, mask + 16, comp_pred + 16);
diff --git a/aom_dsp/x86/masked_variance_intrin_ssse3.h b/aom_dsp/x86/masked_variance_intrin_ssse3.h
new file mode 100644
index 0000000..4d058a1
--- /dev/null
+++ b/aom_dsp/x86/masked_variance_intrin_ssse3.h
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_DSP_X86_MASKED_VARIANCE_INTRIN_SSSE3_H_
+#define AOM_DSP_X86_MASKED_VARIANCE_INTRIN_SSSE3_H_
+
+#include <stdlib.h>
+#include <string.h>
+#include <tmmintrin.h>
+
+#include "./aom_config.h"
+#include "./aom_dsp_rtcd.h"
+#include "aom_dsp/blend.h"
+
+static INLINE void comp_mask_pred_16_ssse3(const uint8_t *src0,
+                                           const uint8_t *src1,
+                                           const uint8_t *mask, uint8_t *dst) {
+  const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
+  const __m128i round_offset =
+      _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
+
+  const __m128i sA0 = _mm_lddqu_si128((const __m128i *)(src0));
+  const __m128i sA1 = _mm_lddqu_si128((const __m128i *)(src1));
+  const __m128i aA = _mm_load_si128((const __m128i *)(mask));
+
+  const __m128i maA = _mm_sub_epi8(alpha_max, aA);
+
+  const __m128i ssAL = _mm_unpacklo_epi8(sA0, sA1);
+  const __m128i aaAL = _mm_unpacklo_epi8(aA, maA);
+  const __m128i ssAH = _mm_unpackhi_epi8(sA0, sA1);
+  const __m128i aaAH = _mm_unpackhi_epi8(aA, maA);
+
+  const __m128i blendAL = _mm_maddubs_epi16(ssAL, aaAL);
+  const __m128i blendAH = _mm_maddubs_epi16(ssAH, aaAH);
+
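+  // mulhrs with 1 << (15 - 6) implements the final (x + 32) >> 6 rounding.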
+  const __m128i roundAL = _mm_mulhrs_epi16(blendAL, round_offset);
+  const __m128i roundAH = _mm_mulhrs_epi16(blendAH, round_offset);
+  _mm_store_si128((__m128i *)dst, _mm_packus_epi16(roundAL, roundAH));
+}
+
+static INLINE void comp_mask_pred_8_ssse3(uint8_t *comp_pred, int height,
+                                          const uint8_t *src0, int stride0,
+                                          const uint8_t *src1, int stride1,
+                                          const uint8_t *mask,
+                                          int mask_stride) {
+  int i = 0;
+  const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
+  const __m128i round_offset =
+      _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
+  do {
+    // odd line A
+    const __m128i sA0 = _mm_loadl_epi64((const __m128i *)(src0));
+    const __m128i sA1 = _mm_loadl_epi64((const __m128i *)(src1));
+    const __m128i aA = _mm_loadl_epi64((const __m128i *)(mask));
+    // even line B
+    const __m128i sB0 = _mm_loadl_epi64((const __m128i *)(src0 + stride0));
+    const __m128i sB1 = _mm_loadl_epi64((const __m128i *)(src1 + stride1));
+    const __m128i a = _mm_castps_si128(_mm_loadh_pi(
+        _mm_castsi128_ps(aA), (const __m64 *)(mask + mask_stride)));
+
+    const __m128i ssA = _mm_unpacklo_epi8(sA0, sA1);
+    const __m128i ssB = _mm_unpacklo_epi8(sB0, sB1);
+
+    const __m128i ma = _mm_sub_epi8(alpha_max, a);
+    const __m128i aaA = _mm_unpacklo_epi8(a, ma);
+    const __m128i aaB = _mm_unpackhi_epi8(a, ma);
+
+    const __m128i blendA = _mm_maddubs_epi16(ssA, aaA);
+    const __m128i blendB = _mm_maddubs_epi16(ssB, aaB);
+    const __m128i roundA = _mm_mulhrs_epi16(blendA, round_offset);
+    const __m128i roundB = _mm_mulhrs_epi16(blendB, round_offset);
+    const __m128i round = _mm_packus_epi16(roundA, roundB);
+    // comp_pred's stride == width == 8
+    _mm_store_si128((__m128i *)(comp_pred), round);
+    comp_pred += (8 << 1);
+    src0 += (stride0 << 1);
+    src1 += (stride1 << 1);
+    mask += (mask_stride << 1);
+    i += 2;
+  } while (i < height);
+}
+
+#endif  // AOM_DSP_X86_MASKED_VARIANCE_INTRIN_SSSE3_H_
diff --git a/aom_dsp/x86/variance_avx2.c b/aom_dsp/x86/variance_avx2.c
index 18a70df..a041bba 100644
--- a/aom_dsp/x86/variance_avx2.c
+++ b/aom_dsp/x86/variance_avx2.c
@@ -11,6 +11,7 @@
 
 #include <immintrin.h>
 #include "./aom_dsp_rtcd.h"
+#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
 
 typedef void (*get_var_avx2)(const uint8_t *src, int src_stride,
                              const uint8_t *ref, int ref_stride,
@@ -190,3 +191,92 @@
   _mm256_zeroupper();
   return variance;
 }
+
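+// Loads two 16-byte rows into one 256-bit register: p1 fills the low 128-bit
+// lane, p0 the high lane.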
+static INLINE __m256i mm256_loadu2(const uint8_t *p0, const uint8_t *p1) {
+  const __m256i d =
+      _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)p1));
+  return _mm256_insertf128_si256(d, _mm_loadu_si128((const __m128i *)p0), 1);
+}
+
+static INLINE void comp_mask_pred_line_avx2(const __m256i s0, const __m256i s1,
+                                            const __m256i a,
+                                            uint8_t *comp_pred) {
+  const __m256i alpha_max = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
+  const int16_t round_bits = 15 - AOM_BLEND_A64_ROUND_BITS;
+  const __m256i round_offset = _mm256_set1_epi16(1 << (round_bits));
+
+  const __m256i ma = _mm256_sub_epi8(alpha_max, a);
+
+  const __m256i ssAL = _mm256_unpacklo_epi8(s0, s1);
+  const __m256i aaAL = _mm256_unpacklo_epi8(a, ma);
+  const __m256i ssAH = _mm256_unpackhi_epi8(s0, s1);
+  const __m256i aaAH = _mm256_unpackhi_epi8(a, ma);
+
+  const __m256i blendAL = _mm256_maddubs_epi16(ssAL, aaAL);
+  const __m256i blendAH = _mm256_maddubs_epi16(ssAH, aaAH);
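+  // mulhrs with 1 << (15 - 6) implements the (x + 32) >> 6 blend rounding.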
+  const __m256i roundAL = _mm256_mulhrs_epi16(blendAL, round_offset);
+  const __m256i roundAH = _mm256_mulhrs_epi16(blendAH, round_offset);
+
+  const __m256i roundA = _mm256_packus_epi16(roundAL, roundAH);
+  _mm256_storeu_si256((__m256i *)(comp_pred), roundA);
+}
+
+void aom_comp_mask_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
+                             int height, const uint8_t *ref, int ref_stride,
+                             const uint8_t *mask, int mask_stride,
+                             int invert_mask) {
+  int i = 0;
+  const uint8_t *src0 = invert_mask ? pred : ref;
+  const uint8_t *src1 = invert_mask ? ref : pred;
+  const int stride0 = invert_mask ? width : ref_stride;
+  const int stride1 = invert_mask ? ref_stride : width;
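+  // Width 8 shows no AVX2 gain, so reuse the SSSE3 8-wide kernel.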
+  if (width == 8) {
+    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
+                           mask, mask_stride);
+  } else if (width == 16) {
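+    // Process four rows per iteration, two 16-pixel rows per ymm register.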
+    do {
+      const __m256i sA0 = mm256_loadu2(src0 + stride0, src0);
+      const __m256i sA1 = mm256_loadu2(src1 + stride1, src1);
+      const __m256i aA = mm256_loadu2(mask + mask_stride, mask);
+      src0 += (stride0 << 1);
+      src1 += (stride1 << 1);
+      mask += (mask_stride << 1);
+      const __m256i sB0 = mm256_loadu2(src0 + stride0, src0);
+      const __m256i sB1 = mm256_loadu2(src1 + stride1, src1);
+      const __m256i aB = mm256_loadu2(mask + mask_stride, mask);
+      src0 += (stride0 << 1);
+      src1 += (stride1 << 1);
+      mask += (mask_stride << 1);
+      // comp_pred's stride == width == 16
+      comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
+      comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
+      comp_pred += (16 << 2);
+      i += 4;
+    } while (i < height);
+  } else {  // width == 32
+    do {
+      const __m256i sA0 = _mm256_lddqu_si256((const __m256i *)(src0));
+      const __m256i sA1 = _mm256_lddqu_si256((const __m256i *)(src1));
+      const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask));
+
+      const __m256i sB0 = _mm256_lddqu_si256((const __m256i *)(src0 + stride0));
+      const __m256i sB1 = _mm256_lddqu_si256((const __m256i *)(src1 + stride1));
+      const __m256i aB =
+          _mm256_lddqu_si256((const __m256i *)(mask + mask_stride));
+
+      comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
+      comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
+      comp_pred += (32 << 1);
+
+      src0 += (stride0 << 1);
+      src1 += (stride1 << 1);
+      mask += (mask_stride << 1);
+      i += 2;
+    } while (i < height);
+  }
+}
diff --git a/test/comp_mask_variance_test.cc b/test/comp_mask_variance_test.cc
index acb0e8c..b6fc75b 100644
--- a/test/comp_mask_variance_test.cc
+++ b/test/comp_mask_variance_test.cc
@@ -11,7 +11,6 @@
 
 #include <cstdlib>
 #include <new>
-#include <vector>
 
 #include "./aom_config.h"
 #include "./aom_dsp_rtcd.h"
@@ -28,15 +27,13 @@
 #include "test/util.h"
 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
 
-using std::vector;
-
 namespace AV1CompMaskVariance {
 typedef void (*comp_mask_pred_func)(uint8_t *comp_pred, const uint8_t *pred,
                                     int width, int height, const uint8_t *ref,
                                     int ref_stride, const uint8_t *mask,
                                     int mask_stride, int invert_mask);
 
-const BLOCK_SIZE valid_bsize[] = {
+const BLOCK_SIZE kValidBlockSize[] = {
   BLOCK_8X8,   BLOCK_8X16, BLOCK_8X32,  BLOCK_16X8,  BLOCK_16X16,
   BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32,
 };
@@ -53,12 +50,13 @@
  protected:
   void RunCheckOutput(comp_mask_pred_func test_impl, BLOCK_SIZE bsize, int inv);
   void RunSpeedTest(comp_mask_pred_func test_impl, BLOCK_SIZE bsize);
-  bool CheckResult(int w, int h) {
-    for (int i = 0; i < h; ++i) {
-      for (int j = 0; j < w; ++j) {
-        int idx = i * w + j;
+  bool CheckResult(int width, int height) {
+    for (int y = 0; y < height; ++y) {
+      for (int x = 0; x < width; ++x) {
+        const int idx = y * width + x;
         if (comp_pred1_[idx] != comp_pred2_[idx]) {
-          printf("%dx%d mismatch @%d(%d,%d) ", w, h, idx, i, j);
+          printf("%dx%d mismatch @%d(%d,%d) ", width, height, idx, y, x);
+          printf("%d != %d ", comp_pred1_[idx], comp_pred2_[idx]);
           return false;
         }
       }
@@ -160,7 +158,14 @@
 INSTANTIATE_TEST_CASE_P(
     SSSE3, AV1CompMaskVarianceTest,
     ::testing::Combine(::testing::Values(&aom_comp_mask_pred_ssse3),
-                       ::testing::ValuesIn(valid_bsize)));
+                       ::testing::ValuesIn(kValidBlockSize)));
+#endif
+
+#if HAVE_AVX2
+INSTANTIATE_TEST_CASE_P(
+    AVX2, AV1CompMaskVarianceTest,
+    ::testing::Combine(::testing::Values(&aom_comp_mask_pred_avx2),
+                       ::testing::ValuesIn(kValidBlockSize)));
 #endif
 
 #ifndef aom_comp_mask_pred
@@ -249,7 +254,15 @@
 INSTANTIATE_TEST_CASE_P(
     SSSE3, AV1CompMaskUpVarianceTest,
     ::testing::Combine(::testing::Values(&aom_comp_mask_pred_ssse3),
-                       ::testing::ValuesIn(valid_bsize)));
+                       ::testing::ValuesIn(kValidBlockSize)));
 #endif
+
+#if HAVE_AVX2
+INSTANTIATE_TEST_CASE_P(
+    AVX2, AV1CompMaskUpVarianceTest,
+    ::testing::Combine(::testing::Values(&aom_comp_mask_pred_avx2),
+                       ::testing::ValuesIn(kValidBlockSize)));
 #endif
+
+#endif  // ifndef aom_comp_mask_pred
 }  // namespace AV1CompMaskVariance