Add tests for low precision Hadamard transform

Change-Id: Ie858dc8707c23aacd049f3006cf80567e8013335
diff --git a/test/hadamard_test.cc b/test/hadamard_test.cc
index 0141125..30c45fc 100644
--- a/test/hadamard_test.cc
+++ b/test/hadamard_test.cc
@@ -23,16 +23,20 @@
 
 using libaom_test::ACMRandom;
 
-typedef void (*HadamardFunc)(const int16_t *a, ptrdiff_t a_stride,
-                             tran_low_t *b);
+using HadamardFunc = void (*)(const int16_t *a, ptrdiff_t a_stride,
+                              tran_low_t *b);
+// Low precision version of Hadamard Transform
+using HadamardLPFunc = void (*)(const int16_t *a, ptrdiff_t a_stride,
+                                int16_t *b);
 
-void HadamardLoop(const tran_low_t *a, tran_low_t *out) {
-  tran_low_t b[8];
+template <typename OutputType>
+void HadamardLoop(const OutputType *a, OutputType *out) {
+  OutputType b[8];
   for (int i = 0; i < 8; i += 2) {
     b[i + 0] = a[i * 8] + a[(i + 1) * 8];
     b[i + 1] = a[i * 8] - a[(i + 1) * 8];
   }
-  tran_low_t c[8];
+  OutputType c[8];
   for (int i = 0; i < 8; i += 4) {
     c[i + 0] = b[i + 0] + b[i + 2];
     c[i + 1] = b[i + 1] + b[i + 3];
@@ -49,19 +53,21 @@
   out[5] = c[3] - c[7];
 }
 
-void ReferenceHadamard8x8(const int16_t *a, int a_stride, tran_low_t *b) {
-  tran_low_t input[64];
-  tran_low_t buf[64];
+template <typename OutputType>
+void ReferenceHadamard8x8(const int16_t *a, int a_stride, OutputType *b) {
+  OutputType input[64];
+  OutputType buf[64];
   for (int i = 0; i < 8; ++i) {
     for (int j = 0; j < 8; ++j) {
-      input[i * 8 + j] = static_cast<tran_low_t>(a[i * a_stride + j]);
+      input[i * 8 + j] = static_cast<OutputType>(a[i * a_stride + j]);
     }
   }
   for (int i = 0; i < 8; ++i) HadamardLoop(input + i, buf + i * 8);
   for (int i = 0; i < 8; ++i) HadamardLoop(buf + i, b + i * 8);
 }
 
-void ReferenceHadamard16x16(const int16_t *a, int a_stride, tran_low_t *b) {
+template <typename OutputType>
+void ReferenceHadamard16x16(const int16_t *a, int a_stride, OutputType *b) {
   /* The source is a 16x16 block. The destination is rearranged to 8x32.
    * Input is 9 bit. */
   ReferenceHadamard8x8(a + 0 + 0 * a_stride, a_stride, b + 0);
@@ -72,16 +78,16 @@
   /* Overlay the 8x8 blocks and combine. */
   for (int i = 0; i < 64; ++i) {
     /* 8x8 steps the range up to 15 bits. */
-    const tran_low_t a0 = b[0];
-    const tran_low_t a1 = b[64];
-    const tran_low_t a2 = b[128];
-    const tran_low_t a3 = b[192];
+    const OutputType a0 = b[0];
+    const OutputType a1 = b[64];
+    const OutputType a2 = b[128];
+    const OutputType a3 = b[192];
 
     /* Prevent the result from escaping int16_t. */
-    const tran_low_t b0 = (a0 + a1) >> 1;
-    const tran_low_t b1 = (a0 - a1) >> 1;
-    const tran_low_t b2 = (a2 + a3) >> 1;
-    const tran_low_t b3 = (a2 - a3) >> 1;
+    const OutputType b0 = (a0 + a1) >> 1;
+    const OutputType b1 = (a0 - a1) >> 1;
+    const OutputType b2 = (a2 + a3) >> 1;
+    const OutputType b3 = (a2 - a3) >> 1;
 
     /* Store a 16 bit value. */
     b[0] = b0 + b2;
@@ -93,22 +99,23 @@
   }
 }
 
-void ReferenceHadamard32x32(const int16_t *a, int a_stride, tran_low_t *b) {
+template <typename OutputType>
+void ReferenceHadamard32x32(const int16_t *a, int a_stride, OutputType *b) {
   ReferenceHadamard16x16(a + 0 + 0 * a_stride, a_stride, b + 0);
   ReferenceHadamard16x16(a + 16 + 0 * a_stride, a_stride, b + 256);
   ReferenceHadamard16x16(a + 0 + 16 * a_stride, a_stride, b + 512);
   ReferenceHadamard16x16(a + 16 + 16 * a_stride, a_stride, b + 768);
 
   for (int i = 0; i < 256; ++i) {
-    const tran_low_t a0 = b[0];
-    const tran_low_t a1 = b[256];
-    const tran_low_t a2 = b[512];
-    const tran_low_t a3 = b[768];
+    const OutputType a0 = b[0];
+    const OutputType a1 = b[256];
+    const OutputType a2 = b[512];
+    const OutputType a3 = b[768];
 
-    const tran_low_t b0 = (a0 + a1) >> 2;
-    const tran_low_t b1 = (a0 - a1) >> 2;
-    const tran_low_t b2 = (a2 + a3) >> 2;
-    const tran_low_t b3 = (a2 - a3) >> 2;
+    const OutputType b0 = (a0 + a1) >> 2;
+    const OutputType b1 = (a0 - a1) >> 2;
+    const OutputType b2 = (a2 + a3) >> 2;
+    const OutputType b3 = (a2 - a3) >> 2;
 
     b[0] = b0 + b2;
     b[256] = b1 + b3;
@@ -119,45 +126,57 @@
   }
 }
 
-struct HadamardFuncWithSize {
-  HadamardFuncWithSize(HadamardFunc f, int s) : func(f), block_size(s) {}
-  HadamardFunc func;
+template <typename OutputType>
+void ReferenceHadamard(const int16_t *a, int a_stride, OutputType *b, int bwh) {
+  if (bwh == 32)
+    ReferenceHadamard32x32(a, a_stride, b);
+  else if (bwh == 16)
+    ReferenceHadamard16x16(a, a_stride, b);
+  else if (bwh == 8) {
+    ReferenceHadamard8x8(a, a_stride, b);
+  } else {
+    GTEST_FAIL() << "Invalid Hadamard transform size " << bwh << std::endl;
+  }
+}
+
+template <typename HadamardFuncType>
+struct FuncWithSize {
+  FuncWithSize(HadamardFuncType f, int s) : func(f), block_size(s) {}
+  HadamardFuncType func;
   int block_size;
 };
 
-std::ostream &operator<<(std::ostream &os, const HadamardFuncWithSize &hfs) {
+using HadamardFuncWithSize = FuncWithSize<HadamardFunc>;
+using HadamardLPFuncWithSize = FuncWithSize<HadamardLPFunc>;
+
+template <typename HadamardFuncType>
+std::ostream &operator<<(std::ostream &os,
+                         const FuncWithSize<HadamardFuncType> &hfs) {
   return os << "block size: " << hfs.block_size;
 }
 
-class HadamardTestBase : public ::testing::TestWithParam<HadamardFuncWithSize> {
+template <typename OutputType, typename HadamardFuncType>
+class HadamardTestBase
+    : public ::testing::TestWithParam<FuncWithSize<HadamardFuncType>> {
  public:
-  virtual void SetUp() {
-    h_func_ = GetParam().func;
-    bwh_ = GetParam().block_size;
+  explicit HadamardTestBase(const FuncWithSize<HadamardFuncType> &func_param) {
+    h_func_ = func_param.func;
+    bwh_ = func_param.block_size;
     block_size_ = bwh_ * bwh_;
-    rnd_.Reset(ACMRandom::DeterministicSeed());
   }
 
+  virtual void SetUp() { rnd_.Reset(ACMRandom::DeterministicSeed()); }
+
   virtual int16_t Rand() = 0;
 
-  void ReferenceHadamard(const int16_t *a, int a_stride, tran_low_t *b,
-                         int bwh) {
-    if (bwh == 32)
-      ReferenceHadamard32x32(a, a_stride, b);
-    else if (bwh == 16)
-      ReferenceHadamard16x16(a, a_stride, b);
-    else
-      ReferenceHadamard8x8(a, a_stride, b);
-  }
-
   void CompareReferenceRandom() {
     const int kMaxBlockSize = 32 * 32;
     DECLARE_ALIGNED(16, int16_t, a[kMaxBlockSize]);
-    DECLARE_ALIGNED(16, tran_low_t, b[kMaxBlockSize]);
+    DECLARE_ALIGNED(16, OutputType, b[kMaxBlockSize]);
     memset(a, 0, sizeof(a));
     memset(b, 0, sizeof(b));
 
-    tran_low_t b_ref[kMaxBlockSize];
+    OutputType b_ref[kMaxBlockSize];
     memset(b_ref, 0, sizeof(b_ref));
 
     for (int i = 0; i < block_size_; ++i) a[i] = Rand();
@@ -174,11 +193,11 @@
   void VaryStride() {
     const int kMaxBlockSize = 32 * 32;
     DECLARE_ALIGNED(16, int16_t, a[kMaxBlockSize * 8]);
-    DECLARE_ALIGNED(16, tran_low_t, b[kMaxBlockSize]);
+    DECLARE_ALIGNED(16, OutputType, b[kMaxBlockSize]);
     memset(a, 0, sizeof(a));
     for (int i = 0; i < block_size_ * 8; ++i) a[i] = Rand();
 
-    tran_low_t b_ref[kMaxBlockSize];
+    OutputType b_ref[kMaxBlockSize];
     for (int i = 8; i < 64; i += 8) {
       memset(b, 0, sizeof(b));
       memset(b_ref, 0, sizeof(b_ref));
@@ -196,7 +215,7 @@
   void SpeedTest(int times) {
     const int kMaxBlockSize = 32 * 32;
     DECLARE_ALIGNED(16, int16_t, input[kMaxBlockSize]);
-    DECLARE_ALIGNED(16, tran_low_t, output[kMaxBlockSize]);
+    DECLARE_ALIGNED(16, OutputType, output[kMaxBlockSize]);
     memset(input, 1, sizeof(input));
     memset(output, 0, sizeof(output));
 
@@ -217,11 +236,12 @@
  private:
   int bwh_;
   int block_size_;
-  HadamardFunc h_func_;
+  HadamardFuncType h_func_;
 };
 
-class HadamardLowbdTest : public HadamardTestBase {
+class HadamardLowbdTest : public HadamardTestBase<tran_low_t, HadamardFunc> {
  public:
+  HadamardLowbdTest() : HadamardTestBase(GetParam()) {}
   virtual int16_t Rand() { return rnd_.Rand9Signed(); }
 };
 
@@ -229,6 +249,8 @@
 
 TEST_P(HadamardLowbdTest, VaryStride) { VaryStride(); }
 
+TEST_P(HadamardLowbdTest, DISABLED_SpeedTest) { SpeedTest(1000000); }
+
 INSTANTIATE_TEST_SUITE_P(
     C, HadamardLowbdTest,
     ::testing::Values(HadamardFuncWithSize(&aom_hadamard_8x8_c, 8),
@@ -257,4 +279,43 @@
                       HadamardFuncWithSize(&aom_hadamard_16x16_neon, 16)));
 #endif  // HAVE_NEON
 
+// Tests for low precision
+class HadamardLowbdLPTest : public HadamardTestBase<int16_t, HadamardLPFunc> {
+ public:
+  HadamardLowbdLPTest() : HadamardTestBase(GetParam()) {}
+  virtual int16_t Rand() { return rnd_.Rand9Signed(); }
+};
+
+TEST_P(HadamardLowbdLPTest, CompareReferenceRandom) {
+  CompareReferenceRandom();
+}
+
+TEST_P(HadamardLowbdLPTest, VaryStride) { VaryStride(); }
+
+TEST_P(HadamardLowbdLPTest, DISABLED_SpeedTest) { SpeedTest(1000000); }
+
+INSTANTIATE_TEST_SUITE_P(
+    C, HadamardLowbdLPTest,
+    ::testing::Values(HadamardLPFuncWithSize(&aom_hadamard_lp_8x8_c, 8),
+                      HadamardLPFuncWithSize(&aom_hadamard_lp_16x16_c, 16)));
+
+#if HAVE_SSE2
+INSTANTIATE_TEST_SUITE_P(
+    SSE2, HadamardLowbdLPTest,
+    ::testing::Values(HadamardLPFuncWithSize(&aom_hadamard_lp_8x8_sse2, 8)));
+#endif  // HAVE_SSE2
+
+#if HAVE_AVX2
+INSTANTIATE_TEST_SUITE_P(
+    AVX2, HadamardLowbdLPTest,
+    ::testing::Values(HadamardLPFuncWithSize(&aom_hadamard_lp_16x16_avx2, 16)));
+#endif  // HAVE_AVX2
+
+#if HAVE_NEON
+INSTANTIATE_TEST_SUITE_P(
+    NEON, HadamardLowbdLPTest,
+    ::testing::Values(HadamardLPFuncWithSize(&aom_hadamard_lp_8x8_neon, 8),
+                      HadamardLPFuncWithSize(&aom_hadamard_lp_16x16_neon, 16)));
+#endif  // HAVE_NEON
+
 }  // namespace