SSE2 optimization of aom_highbd_comp_mask_pred

SSE2 optimization of aom_highbd_comp_mask_pred_c
has been added.

aom_highbd_comp_mask_pred_sse2: ~ 4.3 times faster
than its c implementation.

Change-Id: Ia7a14241826223c847fc19a7a4209dce67e165b9
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 84841c3..16512c5 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1540,7 +1540,7 @@
   specialize qw/aom_comp_mask_pred ssse3 avx2/;
 
   add_proto qw/void aom_highbd_comp_mask_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
-  specialize qw/aom_highbd_comp_mask_pred avx2/;
+  specialize qw/aom_highbd_comp_mask_pred sse2 avx2/;
 
 }  # CONFIG_AV1_ENCODER
 
diff --git a/aom_dsp/x86/variance_sse2.c b/aom_dsp/x86/variance_sse2.c
index 7e3c5d5..9efddb9 100644
--- a/aom_dsp/x86/variance_sse2.c
+++ b/aom_dsp/x86/variance_sse2.c
@@ -16,6 +16,7 @@
 #include "config/aom_dsp_rtcd.h"
 #include "config/av1_rtcd.h"
 
+#include "aom_dsp/blend.h"
 #include "aom_dsp/x86/synonyms.h"
 
 #include "aom_ports/mem.h"
@@ -664,3 +665,110 @@
     pred += 16;
   }
 }
+
+static INLINE __m128i highbd_comp_mask_pred_line_sse2(const __m128i s0,
+                                                      const __m128i s1,
+                                                      const __m128i a) {
+  const __m128i alpha_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
+  const __m128i round_const =
+      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
+  const __m128i a_inv = _mm_sub_epi16(alpha_max, a);
+
+  const __m128i s_lo = _mm_unpacklo_epi16(s0, s1);
+  const __m128i a_lo = _mm_unpacklo_epi16(a, a_inv);
+  const __m128i pred_lo = _mm_madd_epi16(s_lo, a_lo);
+  const __m128i pred_l = _mm_srai_epi32(_mm_add_epi32(pred_lo, round_const),
+                                        AOM_BLEND_A64_ROUND_BITS);
+
+  const __m128i s_hi = _mm_unpackhi_epi16(s0, s1);
+  const __m128i a_hi = _mm_unpackhi_epi16(a, a_inv);
+  const __m128i pred_hi = _mm_madd_epi16(s_hi, a_hi);
+  const __m128i pred_h = _mm_srai_epi32(_mm_add_epi32(pred_hi, round_const),
+                                        AOM_BLEND_A64_ROUND_BITS);
+
+  const __m128i comp = _mm_packs_epi32(pred_l, pred_h);
+
+  return comp;
+}
+
+void aom_highbd_comp_mask_pred_sse2(uint16_t *comp_pred, const uint8_t *pred8,
+                                    int width, int height, const uint8_t *ref8,
+                                    int ref_stride, const uint8_t *mask,
+                                    int mask_stride, int invert_mask) {
+  int i = 0;
+  uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
+  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+  const uint16_t *src0 = invert_mask ? pred : ref;
+  const uint16_t *src1 = invert_mask ? ref : pred;
+  const int stride0 = invert_mask ? width : ref_stride;
+  const int stride1 = invert_mask ? ref_stride : width;
+  const __m128i zero = _mm_setzero_si128();
+
+  if (width == 8) {
+    do {
+      const __m128i s0 = _mm_loadu_si128((const __m128i *)(src0));
+      const __m128i s1 = _mm_loadu_si128((const __m128i *)(src1));
+      const __m128i m_8 = _mm_loadl_epi64((const __m128i *)mask);
+      const __m128i m_16 = _mm_unpacklo_epi8(m_8, zero);
+
+      const __m128i comp = highbd_comp_mask_pred_line_sse2(s0, s1, m_16);
+
+      _mm_storeu_si128((__m128i *)comp_pred, comp);
+
+      src0 += stride0;
+      src1 += stride1;
+      mask += mask_stride;
+      comp_pred += width;
+      i += 1;
+    } while (i < height);
+  } else if (width == 16) {
+    do {
+      const __m128i s0 = _mm_loadu_si128((const __m128i *)(src0));
+      const __m128i s2 = _mm_loadu_si128((const __m128i *)(src0 + 8));
+      const __m128i s1 = _mm_loadu_si128((const __m128i *)(src1));
+      const __m128i s3 = _mm_loadu_si128((const __m128i *)(src1 + 8));
+
+      const __m128i m_8 = _mm_loadu_si128((const __m128i *)mask);
+      const __m128i m01_16 = _mm_unpacklo_epi8(m_8, zero);
+      const __m128i m23_16 = _mm_unpackhi_epi8(m_8, zero);
+
+      const __m128i comp = highbd_comp_mask_pred_line_sse2(s0, s1, m01_16);
+      const __m128i comp1 = highbd_comp_mask_pred_line_sse2(s2, s3, m23_16);
+
+      _mm_storeu_si128((__m128i *)comp_pred, comp);
+      _mm_storeu_si128((__m128i *)(comp_pred + 8), comp1);
+
+      src0 += stride0;
+      src1 += stride1;
+      mask += mask_stride;
+      comp_pred += width;
+      i += 1;
+    } while (i < height);
+  } else if (width == 32) {
+    do {
+      for (int j = 0; j < 2; j++) {
+        const __m128i s0 = _mm_loadu_si128((const __m128i *)(src0 + j * 16));
+        const __m128i s2 =
+            _mm_loadu_si128((const __m128i *)(src0 + 8 + j * 16));
+        const __m128i s1 = _mm_loadu_si128((const __m128i *)(src1 + j * 16));
+        const __m128i s3 =
+            _mm_loadu_si128((const __m128i *)(src1 + 8 + j * 16));
+
+        const __m128i m_8 = _mm_loadu_si128((const __m128i *)(mask + j * 16));
+        const __m128i m01_16 = _mm_unpacklo_epi8(m_8, zero);
+        const __m128i m23_16 = _mm_unpackhi_epi8(m_8, zero);
+
+        const __m128i comp = highbd_comp_mask_pred_line_sse2(s0, s1, m01_16);
+        const __m128i comp1 = highbd_comp_mask_pred_line_sse2(s2, s3, m23_16);
+
+        _mm_storeu_si128((__m128i *)(comp_pred + j * 16), comp);
+        _mm_storeu_si128((__m128i *)(comp_pred + 8 + j * 16), comp1);
+      }
+      src0 += stride0;
+      src1 += stride1;
+      mask += mask_stride;
+      comp_pred += width;
+      i += 1;
+    } while (i < height);
+  }
+}
diff --git a/test/comp_mask_variance_test.cc b/test/comp_mask_variance_test.cc
index 0016ddd..2d842c9 100644
--- a/test/comp_mask_variance_test.cc
+++ b/test/comp_mask_variance_test.cc
@@ -428,6 +428,14 @@
                        ::testing::Range(8, 13, 2)));
 #endif
 
+#if HAVE_SSE2
+INSTANTIATE_TEST_CASE_P(
+    SSE2, AV1HighbdCompMaskVarianceTest,
+    ::testing::Combine(::testing::Values(&aom_highbd_comp_mask_pred_sse2),
+                       ::testing::ValuesIn(kValidBlockSize),
+                       ::testing::Range(8, 13, 2)));
+#endif
+
 #ifndef aom_highbd_comp_mask_pred
 // can't run this test if aom_highbd_comp_mask_pred is defined to
 // aom_highbd_comp_mask_pred_c
@@ -540,5 +548,13 @@
                        ::testing::Range(8, 13, 2)));
 #endif
 
+#if HAVE_SSE2
+INSTANTIATE_TEST_CASE_P(
+    SSE2, AV1HighbdCompMaskUpVarianceTest,
+    ::testing::Combine(::testing::Values(&aom_highbd_comp_mask_pred_sse2),
+                       ::testing::ValuesIn(kValidBlockSize),
+                       ::testing::Range(8, 13, 2)));
+#endif
+
 #endif  // ifndef aom_highbd_comp_mask_pred
 }  // namespace AV1CompMaskVariance