Add an SSE implementation of av1_highbd_iwht4x4_16_add

This is actually used in lossless and the SSE implementation is
more than 3 times faster.

Bug: b/191463451

Change-Id: Iaf7586f339a9679d917167faa311daa8277a469a
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index c4777ea..eedea61 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -214,6 +214,7 @@
 
 add_proto qw/void av1_highbd_iwht4x4_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
 add_proto qw/void av1_highbd_iwht4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
+specialize qw/av1_highbd_iwht4x4_16_add  sse4_1/;
 
 add_proto qw/void av1_inv_txfm2d_add_4x8/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
 add_proto qw/void av1_inv_txfm2d_add_8x4/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
diff --git a/av1/common/x86/highbd_inv_txfm_sse4.c b/av1/common/x86/highbd_inv_txfm_sse4.c
index 03eaef8..568ee5c 100644
--- a/av1/common/x86/highbd_inv_txfm_sse4.c
+++ b/av1/common/x86/highbd_inv_txfm_sse4.c
@@ -145,6 +145,74 @@
   in[3] = _mm_load_si128((const __m128i *)(coeff + 12));
 }
 
+void av1_highbd_iwht4x4_16_add_sse4_1(const tran_low_t *input, uint8_t *dest8,
+                                      int stride, int bd) {
+  /* 4-point reversible, orthonormal inverse Walsh-Hadamard in 3.5 adds,
+     0.5 shifts per pixel. */
+  __m128i op[4];
+  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
+
+  load_buffer_4x4(input, op);
+
+  // Shift before-hand.
+  op[0] = _mm_srai_epi32(op[0], UNIT_QUANT_SHIFT);
+  op[1] = _mm_srai_epi32(op[1], UNIT_QUANT_SHIFT);
+  op[2] = _mm_srai_epi32(op[2], UNIT_QUANT_SHIFT);
+  op[3] = _mm_srai_epi32(op[3], UNIT_QUANT_SHIFT);
+
+  for (int i = 0; i < 2; ++i) {
+    transpose_32bit_4x4(op, op);
+
+    __m128i a1 = op[0];
+    __m128i c1 = op[1];
+    __m128i d1 = op[2];
+    __m128i b1 = op[3];
+    a1 = _mm_add_epi32(a1, c1);          // a1 += c1
+    d1 = _mm_sub_epi32(d1, b1);          // d1 -= b1
+    __m128i e1 = _mm_sub_epi32(a1, d1);  // e1 = (a1 - d1) >> 1
+    e1 = _mm_srai_epi32(e1, 1);
+    b1 = _mm_sub_epi32(e1, b1);  // b1 = e1 - b1
+    c1 = _mm_sub_epi32(e1, c1);  // c1 = e1 - c1
+    a1 = _mm_sub_epi32(a1, b1);  // a1 -= b1
+    d1 = _mm_add_epi32(d1, c1);  // d1 += c1
+
+    op[0] = a1;
+    op[1] = b1;
+    op[2] = c1;
+    op[3] = d1;
+  }
+
+  // Convert to int16_t. The C code checks that we are in range.
+  op[0] = _mm_packs_epi32(op[0], op[1]);
+  op[1] = _mm_packs_epi32(op[2], op[3]);
+
+  // Load uint16_t.
+  __m128i dst[2];
+  __m128i tmp[4];
+  tmp[0] = _mm_loadl_epi64((const __m128i *)(dest + 0 * stride));
+  tmp[1] = _mm_loadl_epi64((const __m128i *)(dest + 1 * stride));
+  dst[0] = _mm_unpacklo_epi64(tmp[0], tmp[1]);
+  tmp[2] = _mm_loadl_epi64((const __m128i *)(dest + 2 * stride));
+  tmp[3] = _mm_loadl_epi64((const __m128i *)(dest + 3 * stride));
+  dst[1] = _mm_unpacklo_epi64(tmp[2], tmp[3]);
+
+  // Add to the previous results.
+  dst[0] = _mm_add_epi16(dst[0], op[0]);
+  dst[1] = _mm_add_epi16(dst[1], op[1]);
+
+  // Clamp.
+  dst[0] = highbd_clamp_epi16(dst[0], bd);
+  dst[1] = highbd_clamp_epi16(dst[1], bd);
+
+  // Store.
+  _mm_storel_epi64((__m128i *)(dest + 0 * stride), dst[0]);
+  dst[0] = _mm_srli_si128(dst[0], 8);
+  _mm_storel_epi64((__m128i *)(dest + 1 * stride), dst[0]);
+  _mm_storel_epi64((__m128i *)(dest + 2 * stride), dst[1]);
+  dst[1] = _mm_srli_si128(dst[1], 8);
+  _mm_storel_epi64((__m128i *)(dest + 3 * stride), dst[1]);
+}
+
 static void addsub_sse4_1(const __m128i in0, const __m128i in1, __m128i *out0,
                           __m128i *out1, const __m128i *clamp_lo,
                           const __m128i *clamp_hi) {
diff --git a/test/fwht4x4_test.cc b/test/fwht4x4_test.cc
index 2b470c1..2e27adf 100644
--- a/test/fwht4x4_test.cc
+++ b/test/fwht4x4_test.cc
@@ -44,14 +44,26 @@
   av1_fwht4x4_c(in, out, stride);
 }
 
-void iwht4x4_10(const tran_low_t *in, uint8_t *out, int stride) {
+void iwht4x4_10_c(const tran_low_t *in, uint8_t *out, int stride) {
   av1_highbd_iwht4x4_16_add_c(in, out, stride, 10);
 }
 
-void iwht4x4_12(const tran_low_t *in, uint8_t *out, int stride) {
+void iwht4x4_12_c(const tran_low_t *in, uint8_t *out, int stride) {
   av1_highbd_iwht4x4_16_add_c(in, out, stride, 12);
 }
 
+#if HAVE_SSE4_1
+
+void iwht4x4_10_sse4_1(const tran_low_t *in, uint8_t *out, int stride) {
+  av1_highbd_iwht4x4_16_add_sse4_1(in, out, stride, 10);
+}
+
+void iwht4x4_12_sse4_1(const tran_low_t *in, uint8_t *out, int stride) {
+  av1_highbd_iwht4x4_16_add_sse4_1(in, out, stride, 12);
+}
+
+#endif
+
 class Trans4x4WHT : public libaom_test::TransformTestBase<tran_low_t>,
                     public ::testing::TestWithParam<Dct4x4Param> {
  public:
@@ -176,19 +188,20 @@
 
 INSTANTIATE_TEST_SUITE_P(
     C, Trans4x4WHT,
-    ::testing::Values(make_tuple(&av1_highbd_fwht4x4_c, &iwht4x4_10, DCT_DCT,
+    ::testing::Values(make_tuple(&av1_highbd_fwht4x4_c, &iwht4x4_10_c, DCT_DCT,
                                  AOM_BITS_10, 16, static_cast<FdctFunc>(NULL)),
-                      make_tuple(&av1_highbd_fwht4x4_c, &iwht4x4_12, DCT_DCT,
+                      make_tuple(&av1_highbd_fwht4x4_c, &iwht4x4_12_c, DCT_DCT,
                                  AOM_BITS_12, 16,
                                  static_cast<FdctFunc>(NULL))));
+
 #if HAVE_SSE4_1
 
 INSTANTIATE_TEST_SUITE_P(
     SSE4_1, Trans4x4WHT,
-    ::testing::Values(make_tuple(&av1_highbd_fwht4x4_sse4_1, &iwht4x4_10,
+    ::testing::Values(make_tuple(&av1_highbd_fwht4x4_sse4_1, &iwht4x4_10_sse4_1,
                                  DCT_DCT, AOM_BITS_10, 16,
                                  static_cast<FdctFunc>(NULL)),
-                      make_tuple(&av1_highbd_fwht4x4_sse4_1, &iwht4x4_12,
+                      make_tuple(&av1_highbd_fwht4x4_sse4_1, &iwht4x4_12_sse4_1,
                                  DCT_DCT, AOM_BITS_12, 16,
                                  static_cast<FdctFunc>(NULL))));
 
@@ -198,10 +211,12 @@
 
 INSTANTIATE_TEST_SUITE_P(
     NEON, Trans4x4WHT,
-    ::testing::Values(make_tuple(&av1_highbd_fwht4x4_neon, &iwht4x4_10, DCT_DCT,
-                                 AOM_BITS_10, 16, &av1_highbd_fwht4x4_c),
-                      make_tuple(&av1_highbd_fwht4x4_neon, &iwht4x4_12, DCT_DCT,
-                                 AOM_BITS_12, 16, &av1_highbd_fwht4x4_c)));
+    ::testing::Values(make_tuple(&av1_highbd_fwht4x4_neon, &iwht4x4_10_c,
+                                 DCT_DCT, AOM_BITS_10, 16,
+                                 &av1_highbd_fwht4x4_c),
+                      make_tuple(&av1_highbd_fwht4x4_neon, &iwht4x4_12_c,
+                                 DCT_DCT, AOM_BITS_12, 16,
+                                 &av1_highbd_fwht4x4_c)));
 
 #endif  // HAVE_NEON