Add SSE4.1 optimization of av1_build_compound_diffwtd_mask_d16

An SSE4_1 implementation of av1_build_compound_diffwtd_mask_d16 has been
added. The existing C version is now exported through RTCD as
av1_build_compound_diffwtd_mask_d16_c.

Unit tests have also been added for
av1_build_compound_diffwtd_mask_d16_sse4_1.

av1_build_compound_diffwtd_mask_d16_sse4_1 is ~6 times faster
than the C implementation.

Change-Id: I5d8f9a5f2820c1a1f7ee1d2a8e219af20d4c4701
diff --git a/aom_dsp/blend.h b/aom_dsp/blend.h
index e5297ff..434bb83 100644
--- a/aom_dsp/blend.h
+++ b/aom_dsp/blend.h
@@ -39,4 +39,7 @@
 // Blending by averaging.
 #define AOM_BLEND_AVG(v0, v1) ROUND_POWER_OF_TWO((v0) + (v1), 1)
 
+#define DIFF_FACTOR_LOG2 4  // log2 of the sample-difference divisor
+#define DIFF_FACTOR (1 << DIFF_FACTOR_LOG2)  // i.e. 16
+
 #endif  // AOM_DSP_BLEND_H_
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 2fd4303..8fbd2c1 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -148,6 +148,8 @@
 add_proto qw/void av1_build_compound_diffwtd_mask/, "uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0, int src0_stride, const uint8_t *src1, int src1_stride, int h, int w";
 add_proto qw/void av1_build_compound_diffwtd_mask_highbd/, "uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0, int src0_stride, const uint8_t *src1, int src1_stride, int h, int w, int bd";
 specialize qw/av1_build_compound_diffwtd_mask sse4_1/;
+add_proto qw/void av1_build_compound_diffwtd_mask_d16/, "uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0, int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w, ConvolveParams *conv_params, int bd";
+specialize qw/av1_build_compound_diffwtd_mask_d16 sse4_1/;
 
 #
 # Encoder functions below this point.
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index 5b78bf5..fb63683 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -301,9 +301,7 @@
   }
 }
 
-#define DIFF_FACTOR 16
-
-static void diffwtd_mask_d32(uint8_t *mask, int which_inverse, int mask_base,
+static void diffwtd_mask_d16(uint8_t *mask, int which_inverse, int mask_base,
                              const CONV_BUF_TYPE *src0, int src0_stride,
                              const CONV_BUF_TYPE *src1, int src1_stride, int h,
                              int w, ConvolveParams *conv_params, int bd) {
@@ -320,17 +318,17 @@
   }
 }
 
-static void build_compound_diffwtd_mask_d16(
+void av1_build_compound_diffwtd_mask_d16_c(
     uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0,
     int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
     ConvolveParams *conv_params, int bd) {
   switch (mask_type) {
     case DIFFWTD_38:
-      diffwtd_mask_d32(mask, 0, 38, src0, src0_stride, src1, src1_stride, h, w,
+      diffwtd_mask_d16(mask, 0, 38, src0, src0_stride, src1, src1_stride, h, w,
                        conv_params, bd);
       break;
     case DIFFWTD_38_INV:
-      diffwtd_mask_d32(mask, 1, 38, src0, src0_stride, src1, src1_stride, h, w,
+      diffwtd_mask_d16(mask, 1, 38, src0, src0_stride, src1, src1_stride, h, w,
                        conv_params, bd);
       break;
     default: assert(0);
@@ -632,9 +630,9 @@
                            xd, can_use_previous);
 
   if (!plane && comp_data.interinter_compound_type == COMPOUND_DIFFWTD) {
-    build_compound_diffwtd_mask_d16(comp_data.seg_mask, comp_data.mask_type,
-                                    org_dst, org_dst_stride, tmp_buf16,
-                                    tmp_buf_stride, h, w, conv_params, xd->bd);
+    av1_build_compound_diffwtd_mask_d16(
+        comp_data.seg_mask, comp_data.mask_type, org_dst, org_dst_stride,
+        tmp_buf16, tmp_buf_stride, h, w, conv_params, xd->bd);
   }
   build_masked_compound_no_round(dst, dst_stride, org_dst, org_dst_stride,
                                  tmp_buf16, tmp_buf_stride, &comp_data,
diff --git a/av1/common/x86/reconinter_sse4.c b/av1/common/x86/reconinter_sse4.c
index 9ccd2a0..5171ca4 100644
--- a/av1/common/x86/reconinter_sse4.c
+++ b/av1/common/x86/reconinter_sse4.c
@@ -93,3 +93,61 @@
     } while (i < h);
   }
 }
+
+void av1_build_compound_diffwtd_mask_d16_sse4_1(
+    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0,
+    int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
+    ConvolveParams *conv_params, int bd) {
+  const int which_inverse = (mask_type == DIFFWTD_38) ? 0 : 1;
+  const int mask_base = 38;
+  int round =
+      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1 + (bd - 8);
+  const __m128i round_const = _mm_set1_epi16((1 << round) >> 1);
+  const __m128i mask_base_16 = _mm_set1_epi16(mask_base);
+  const __m128i clip_diff = _mm_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
+  const __m128i add_const =
+      _mm_set1_epi16((which_inverse ? AOM_BLEND_A64_MAX_ALPHA : 0));
+  const __m128i add_sign = _mm_set1_epi16((which_inverse ? -1 : 1));
+  // m = min(38 + (rounded |src0 - src1| >> 4), 64); inverse type: 64 - m.
+  int i, j;
+  // The rounding add below can exceed 16 bits; it is a saturating add, so the
+  // sum sticks at 0xFFFF rather than wrapping. That precision is not needed
+  // (presumably such differences reach the 64 clip anyway - TODO confirm).
+  // Other DIFF_FACTOR_LOG2 / AOM_BLEND_A64_MAX_ALPHA values are untested.
+  assert(DIFF_FACTOR_LOG2 == 4);
+  assert(AOM_BLEND_A64_MAX_ALPHA == 64);
+  for (i = 0; i < h; ++i) {
+    for (j = 0; j < w; j += 8) {  // 8 pixels per step, even if w == 4
+      const __m128i data_src0 =
+          _mm_loadu_si128((__m128i *)&src0[(i * src0_stride) + j]);
+      const __m128i data_src1 =
+          _mm_loadu_si128((__m128i *)&src1[(i * src1_stride) + j]);
+      // |src0 - src1|: one of the two saturating differences is always 0.
+      const __m128i diffa = _mm_subs_epu16(data_src0, data_src1);
+      const __m128i diffb = _mm_subs_epu16(data_src1, data_src0);
+      const __m128i diff = _mm_max_epu16(diffa, diffb);
+      const __m128i diff_round =
+          _mm_srli_epi16(_mm_adds_epu16(diff, round_const), round);
+      const __m128i diff_factor = _mm_srli_epi16(diff_round, DIFF_FACTOR_LOG2);
+      const __m128i diff_mask = _mm_adds_epi16(diff_factor, mask_base_16);
+      __m128i diff_clamp = _mm_min_epi16(diff_mask, clip_diff);
+      // No clamp to 0 is needed: diff_factor comes from a logical right shift
+      // (never negative) and mask_base is positive, so the sum stays above 0.
+      // Regular: +m + 0. Inverse: -m + 64 (AOM_BLEND_A64_MAX_ALPHA - m).
+      const __m128i diff_sign = _mm_sign_epi16(diff_clamp, add_sign);
+      const __m128i diff_const_16 = _mm_add_epi16(diff_sign, add_const);
+
+      // Narrow to bytes with unsigned saturation (both halves are identical).
+      const __m128i res_8 = _mm_packus_epi16(diff_const_16, diff_const_16);
+
+      // The mask is stored densely: row i starts at mask[i * w].
+      __m128i *const dst = (__m128i *)&mask[i * w + j];
+      // Write 8 bytes, or just 4 when no more than 4 columns remain.
+      if ((w - j) > 4) {
+        _mm_storel_epi64(dst, res_8);
+      } else {  // w==4
+        *(uint32_t *)dst = _mm_cvtsi128_si32(res_8);
+      }
+    }
+  }
+}
diff --git a/test/reconinter_test.cc b/test/reconinter_test.cc
index 7c11d98..1161943 100644
--- a/test/reconinter_test.cc
+++ b/test/reconinter_test.cc
@@ -39,6 +39,124 @@
   ACMRandom rnd_;
 };
 
+typedef void (*buildcompdiffwtdmaskd16_func)(
+    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0,
+    int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
+    ConvolveParams *conv_params, int bd);  // same signature as the RTCD entry
+// Test parameters: <0> bit depth, <1> function under test, <2> block size.
+typedef ::testing::tuple<int, buildcompdiffwtdmaskd16_func, BLOCK_SIZE>
+    BuildCompDiffwtdMaskD16Param;
+
+::testing::internal::ParamGenerator<BuildCompDiffwtdMaskD16Param> BuildParams(
+    buildcompdiffwtdmaskd16_func filter) {
+  return ::testing::Combine(::testing::Range(8, 13, 2),  // bd: 8, 10, 12
+                            ::testing::Values(filter),   // function under test
+                            ::testing::Range(BLOCK_4X4, BLOCK_SIZES_ALL));
+}
+
+// Fixture that checks a d16 diff-weighted mask builder against the C version.
+class BuildCompDiffwtdMaskD16Test
+    : public ::testing::TestWithParam<BuildCompDiffwtdMaskD16Param> {
+ public:
+  ~BuildCompDiffwtdMaskD16Test() {}
+  void SetUp() { rnd_.Reset(ACMRandom::DeterministicSeed()); }
+  virtual void TearDown() { libaom_test::ClearSystemState(); }
+ protected:
+  libaom_test::ACMRandom rnd_;
+  void RunSpeedTest(buildcompdiffwtdmaskd16_func test_impl);
+  void RunCheckOutput(buildcompdiffwtdmaskd16_func test_impl);
+};  // class BuildCompDiffwtdMaskD16Test
+
+// Feeds identical random inputs to the C reference and to |test_impl|, for
+// every mask type, and requires the two output masks to match exactly.
+void BuildCompDiffwtdMaskD16Test::RunCheckOutput(
+    buildcompdiffwtdmaskd16_func test_impl) {
+  const int bd = GET_PARAM(0);
+  const int block_idx = GET_PARAM(2);
+  const int height = block_size_high[block_idx];
+  const int width = block_size_wide[block_idx];
+  DECLARE_ALIGNED(32, uint16_t, src0[MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(32, uint16_t, src1[MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(16, uint8_t, mask_ref[2 * MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(16, uint8_t, mask_test[2 * MAX_SB_SQUARE]);
+  ConvolveParams conv_params =
+      get_conv_params_no_round(0, 0, 0, NULL, 0, 1, bd);
+
+  const int in_precision =
+      bd + 2 * FILTER_BITS - conv_params.round_0 - conv_params.round_1 + 2;
+  const int sample_mask = (1 << in_precision) - 1;  // keep in_precision bits
+  for (int idx = 0; idx < MAX_SB_SQUARE; ++idx) {
+    src0[idx] = rnd_.Rand16() & sample_mask;
+    src1[idx] = rnd_.Rand16() & sample_mask;
+  }
+
+  for (int type_idx = 0; type_idx < DIFFWTD_MASK_TYPES; ++type_idx) {
+    const DIFFWTD_MASK_TYPE mask_type = (DIFFWTD_MASK_TYPE)type_idx;
+    av1_build_compound_diffwtd_mask_d16_c(
+        mask_ref, mask_type, src0, width, src1, width, height, width,
+        &conv_params, bd);
+    test_impl(mask_test, mask_type, src0, width, src1, width, height, width,
+              &conv_params, bd);
+    for (int r = 0; r < height; ++r) {
+      for (int c = 0; c < width; ++c) {
+        ASSERT_EQ(mask_ref[c + r * width], mask_test[c + r * width])
+            << "Mismatch at unit tests for BuildCompDiffwtdMaskD16Test\n"
+            << " Pixel mismatch at index "
+            << "[" << r << "," << c << "] "
+            << " @ " << width << "x" << height << " inv " << type_idx;
+      }
+    }
+  }
+}
+
+// Times the C reference and |test_impl| over the same number of calls and
+// prints each one's average time per call in nanoseconds.
+void BuildCompDiffwtdMaskD16Test::RunSpeedTest(
+    buildcompdiffwtdmaskd16_func test_impl) {
+  const int block_idx = GET_PARAM(2);
+  const int bd = GET_PARAM(0);
+  const int width = block_size_wide[block_idx];
+  const int height = block_size_high[block_idx];
+  DECLARE_ALIGNED(16, uint8_t, mask[MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(32, uint16_t, src0[MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(32, uint16_t, src1[MAX_SB_SQUARE]);
+
+  ConvolveParams conv_params =
+      get_conv_params_no_round(0, 0, 0, NULL, 0, 1, bd);
+
+  int in_precision =
+      bd + 2 * FILTER_BITS - conv_params.round_0 - conv_params.round_1 + 2;
+
+  for (int i = 0; i < MAX_SB_SQUARE; i++) {
+    src0[i] = rnd_.Rand16() & ((1 << in_precision) - 1);
+    src1[i] = rnd_.Rand16() & ((1 << in_precision) - 1);
+  }
+
+  const int num_loops = 1000000000 / (width + height);
+  aom_usec_timer timer;
+  aom_usec_timer_start(&timer);
+  for (int i = 0; i < num_loops; ++i)
+    av1_build_compound_diffwtd_mask_d16_c(mask, DIFFWTD_38, src0, width, src1,
+                                          width, height, width, &conv_params,
+                                          bd);
+  aom_usec_timer_mark(&timer);
+
+  // The timer counts microseconds; 1000 * elapsed / calls is ns per call.
+  const int elapsed_time = static_cast<int>(aom_usec_timer_elapsed(&timer));
+  printf("av1_build_compound_diffwtd_mask_d16 c_code %3dx%-3d: %7.2f ns\n",
+         width, height, 1000.0 * elapsed_time / num_loops);
+
+  aom_usec_timer timer1;
+  aom_usec_timer_start(&timer1);
+  for (int i = 0; i < num_loops; ++i)
+    test_impl(mask, DIFFWTD_38, src0, width, src1, width, height, width,
+              &conv_params, bd);
+  aom_usec_timer_mark(&timer1);
+  const int elapsed_time1 = static_cast<int>(aom_usec_timer_elapsed(&timer1));
+  printf("av1_build_compound_diffwtd_mask_d16 test_code %3dx%-3d: %7.2f ns\n",
+         width, height, 1000.0 * elapsed_time1 / num_loops);
+}
+
 void BuildCompDiffwtdMaskTest::RunTest(const int sb_type, const int is_speed,
                                        const DIFFWTD_MASK_TYPE type) {
   const int width = block_size_wide[sb_type];
@@ -88,7 +206,22 @@
   RunTest(GetParam(), 1, DIFFWTD_38_INV);
 }
 
+TEST_P(BuildCompDiffwtdMaskD16Test, CheckOutput) {
+  RunCheckOutput(GET_PARAM(1));  // param <1> is the function under test
+}
+// DISABLED_ prefix: gtest skips this unless disabled tests are requested.
+TEST_P(BuildCompDiffwtdMaskD16Test, DISABLED_Speed) {
+  RunSpeedTest(GET_PARAM(1));
+}
+
+#if HAVE_SSE4_1
 INSTANTIATE_TEST_CASE_P(SSE4_1, BuildCompDiffwtdMaskTest,
                         ::testing::Range(0, static_cast<int>(BLOCK_SIZES_ALL),
                                          1));
+// D16 variant: bit depths 8/10/12 x blocks [BLOCK_4X4, BLOCK_SIZES_ALL).
+INSTANTIATE_TEST_CASE_P(
+    SSE4_1, BuildCompDiffwtdMaskD16Test,
+    BuildParams(av1_build_compound_diffwtd_mask_d16_sse4_1));
+#endif
+
 }  // namespace