Add neon/sse2 optimization for fwd txfm 4x4

Add unit test to verify C/SSE2/NEON are bitexact.

Change-Id: I24bdee1387622510436ed149e7b61c937488a735
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index c84a380..b2b4b19 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -481,8 +481,10 @@
 #
 if (aom_config("CONFIG_AV1_ENCODER") eq "yes"){
     add_proto qw/void aom_fdct4x4/, "const int16_t *input, tran_low_t *output, int stride";
+    specialize qw/aom_fdct4x4 neon sse2/;
 
     add_proto qw/void aom_fdct4x4_lp/, "const int16_t *input, int16_t *output, int stride";
+    specialize qw/aom_fdct4x4_lp neon sse2/;
 
     add_proto qw/void aom_fdct8x8/, "const int16_t *input, tran_low_t *output, int stride";
     specialize qw/aom_fdct8x8 neon sse2/, "$ssse3_x86_64";
diff --git a/aom_dsp/arm/fwd_txfm_neon.c b/aom_dsp/arm/fwd_txfm_neon.c
index e4300c9..ce93523 100644
--- a/aom_dsp/arm/fwd_txfm_neon.c
+++ b/aom_dsp/arm/fwd_txfm_neon.c
@@ -14,9 +14,103 @@
 #include "config/aom_config.h"
 
 #include "aom_dsp/txfm_common.h"
+#include "av1/common/arm/mem_neon.h"
+#include "av1/common/arm/transpose_neon.h"
+
+static void aom_fdct4x4_helper(const int16_t *input, int stride,
+                               int16x4_t *input_0, int16x4_t *input_1,
+                               int16x4_t *input_2, int16x4_t *input_3) {
+  *input_0 = vshl_n_s16(vld1_s16(input + 0 * stride), 4);
+  *input_1 = vshl_n_s16(vld1_s16(input + 1 * stride), 4);
+  *input_2 = vshl_n_s16(vld1_s16(input + 2 * stride), 4);
+  *input_3 = vshl_n_s16(vld1_s16(input + 3 * stride), 4);
+  // If the very first value != 0, then add 1.
+  if (input[0] != 0) {
+    const int16x4_t one = vreinterpret_s16_s64(vdup_n_s64(1));
+    *input_0 = vadd_s16(*input_0, one);
+  }
+
+  for (int i = 0; i < 2; ++i) {
+    const int16x8_t input_01 = vcombine_s16(*input_0, *input_1);
+    const int16x8_t input_32 = vcombine_s16(*input_3, *input_2);
+
+    // in_0 +/- in_3, in_1 +/- in_2
+    const int16x8_t s_01 = vaddq_s16(input_01, input_32);
+    const int16x8_t s_32 = vsubq_s16(input_01, input_32);
+
+    // step_0 +/- step_1, step_2 +/- step_3
+    const int16x4_t s_0 = vget_low_s16(s_01);
+    const int16x4_t s_1 = vget_high_s16(s_01);
+    const int16x4_t s_2 = vget_high_s16(s_32);
+    const int16x4_t s_3 = vget_low_s16(s_32);
+
+    // (s_0 +/- s_1) * cospi_16_64
+    // Must expand all elements to s32. See 'needs32' comment in fwd_txfm.c.
+    const int32x4_t s_0_p_s_1 = vaddl_s16(s_0, s_1);
+    const int32x4_t s_0_m_s_1 = vsubl_s16(s_0, s_1);
+    const int32x4_t temp1 = vmulq_n_s32(s_0_p_s_1, cospi_16_64);
+    const int32x4_t temp2 = vmulq_n_s32(s_0_m_s_1, cospi_16_64);
+
+    // fdct_round_shift
+    int16x4_t out_0 = vrshrn_n_s32(temp1, DCT_CONST_BITS);
+    int16x4_t out_2 = vrshrn_n_s32(temp2, DCT_CONST_BITS);
+
+    // s_3 * cospi_8_64 + s_2 * cospi_24_64
+    // s_3 * cospi_24_64 - s_2 * cospi_8_64
+    const int32x4_t s_3_cospi_8_64 = vmull_n_s16(s_3, cospi_8_64);
+    const int32x4_t s_3_cospi_24_64 = vmull_n_s16(s_3, cospi_24_64);
+
+    const int32x4_t temp3 = vmlal_n_s16(s_3_cospi_8_64, s_2, cospi_24_64);
+    const int32x4_t temp4 = vmlsl_n_s16(s_3_cospi_24_64, s_2, cospi_8_64);
+
+    // fdct_round_shift
+    int16x4_t out_1 = vrshrn_n_s32(temp3, DCT_CONST_BITS);
+    int16x4_t out_3 = vrshrn_n_s32(temp4, DCT_CONST_BITS);
+
+    transpose_s16_4x4d(&out_0, &out_1, &out_2, &out_3);
+
+    *input_0 = out_0;
+    *input_1 = out_1;
+    *input_2 = out_2;
+    *input_3 = out_3;
+  }
+}
+
+void aom_fdct4x4_neon(const int16_t *input, tran_low_t *final_output,
+                      int stride) {
+  // input[M * stride] * 16
+  int16x4_t input_0, input_1, input_2, input_3;
+
+  aom_fdct4x4_helper(input, stride, &input_0, &input_1, &input_2, &input_3);
+
+  // Not quite a rounding shift. Only add 1 despite shifting by 2.
+  const int16x8_t one = vdupq_n_s16(1);
+  int16x8_t out_01 = vcombine_s16(input_0, input_1);
+  int16x8_t out_23 = vcombine_s16(input_2, input_3);
+  out_01 = vshrq_n_s16(vaddq_s16(out_01, one), 2);
+  out_23 = vshrq_n_s16(vaddq_s16(out_23, one), 2);
+  store_s16q_to_tran_low(final_output + 0 * 8, out_01);
+  store_s16q_to_tran_low(final_output + 1 * 8, out_23);
+}
+
+void aom_fdct4x4_lp_neon(const int16_t *input, int16_t *final_output,
+                         int stride) {
+  // input[M * stride] * 16
+  int16x4_t input_0, input_1, input_2, input_3;
+
+  aom_fdct4x4_helper(input, stride, &input_0, &input_1, &input_2, &input_3);
+
+  // Not quite a rounding shift. Only add 1 despite shifting by 2.
+  const int16x8_t one = vdupq_n_s16(1);
+  int16x8_t out_01 = vcombine_s16(input_0, input_1);
+  int16x8_t out_23 = vcombine_s16(input_2, input_3);
+  out_01 = vshrq_n_s16(vaddq_s16(out_01, one), 2);
+  out_23 = vshrq_n_s16(vaddq_s16(out_23, one), 2);
+  vst1q_s16(final_output + 0 * 8, out_01);
+  vst1q_s16(final_output + 1 * 8, out_23);
+}
 
 void aom_fdct8x8_neon(const int16_t *input, int16_t *final_output, int stride) {
-  int i;
   // stage 1
   int16x8_t input_0 = vshlq_n_s16(vld1q_s16(&input[0 * stride]), 2);
   int16x8_t input_1 = vshlq_n_s16(vld1q_s16(&input[1 * stride]), 2);
@@ -26,7 +120,7 @@
   int16x8_t input_5 = vshlq_n_s16(vld1q_s16(&input[5 * stride]), 2);
   int16x8_t input_6 = vshlq_n_s16(vld1q_s16(&input[6 * stride]), 2);
   int16x8_t input_7 = vshlq_n_s16(vld1q_s16(&input[7 * stride]), 2);
-  for (i = 0; i < 2; ++i) {
+  for (int i = 0; i < 2; ++i) {
     int16x8_t out_0, out_1, out_2, out_3, out_4, out_5, out_6, out_7;
     const int16x8_t v_s0 = vaddq_s16(input_0, input_7);
     const int16x8_t v_s1 = vaddq_s16(input_1, input_6);
diff --git a/aom_dsp/x86/fwd_txfm_impl_sse2.h b/aom_dsp/x86/fwd_txfm_impl_sse2.h
index 1e3d13e..89fe189 100644
--- a/aom_dsp/x86/fwd_txfm_impl_sse2.h
+++ b/aom_dsp/x86/fwd_txfm_impl_sse2.h
@@ -30,6 +30,206 @@
 #define SUB_EPI16 _mm_sub_epi16
 #endif
 
+static void FDCT4x4_2D_HELPER(const int16_t *input, int stride, __m128i *in0,
+                              __m128i *in1) {
+  // Constants
+  // These are the coefficients used for the multiplies.
+  // In the comments, pN means cos(N pi /64) and mN is -cos(N pi /64),
+  // where cospi_N_64 = cos(N pi /64)
+  const __m128i k__cospi_A =
+      octa_set_epi16(cospi_16_64, cospi_16_64, cospi_16_64, cospi_16_64,
+                     cospi_16_64, -cospi_16_64, cospi_16_64, -cospi_16_64);
+  const __m128i k__cospi_B =
+      octa_set_epi16(cospi_16_64, -cospi_16_64, cospi_16_64, -cospi_16_64,
+                     cospi_16_64, cospi_16_64, cospi_16_64, cospi_16_64);
+  const __m128i k__cospi_C =
+      octa_set_epi16(cospi_8_64, cospi_24_64, cospi_8_64, cospi_24_64,
+                     cospi_24_64, -cospi_8_64, cospi_24_64, -cospi_8_64);
+  const __m128i k__cospi_D =
+      octa_set_epi16(cospi_24_64, -cospi_8_64, cospi_24_64, -cospi_8_64,
+                     cospi_8_64, cospi_24_64, cospi_8_64, cospi_24_64);
+  const __m128i k__cospi_E =
+      octa_set_epi16(cospi_16_64, cospi_16_64, cospi_16_64, cospi_16_64,
+                     cospi_16_64, cospi_16_64, cospi_16_64, cospi_16_64);
+  const __m128i k__cospi_F =
+      octa_set_epi16(cospi_16_64, -cospi_16_64, cospi_16_64, -cospi_16_64,
+                     cospi_16_64, -cospi_16_64, cospi_16_64, -cospi_16_64);
+  const __m128i k__cospi_G =
+      octa_set_epi16(cospi_8_64, cospi_24_64, cospi_8_64, cospi_24_64,
+                     -cospi_8_64, -cospi_24_64, -cospi_8_64, -cospi_24_64);
+  const __m128i k__cospi_H =
+      octa_set_epi16(cospi_24_64, -cospi_8_64, cospi_24_64, -cospi_8_64,
+                     -cospi_24_64, cospi_8_64, -cospi_24_64, cospi_8_64);
+
+  const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING);
+  // This second rounding constant saves doing some extra adds at the end
+  const __m128i k__DCT_CONST_ROUNDING2 =
+      _mm_set1_epi32(DCT_CONST_ROUNDING + (DCT_CONST_ROUNDING << 1));
+  const int DCT_CONST_BITS2 = DCT_CONST_BITS + 2;
+  const __m128i k__nonzero_bias_a = _mm_setr_epi16(0, 1, 1, 1, 1, 1, 1, 1);
+  const __m128i k__nonzero_bias_b = _mm_setr_epi16(1, 0, 0, 0, 0, 0, 0, 0);
+
+  // Load inputs.
+  *in0 = _mm_loadl_epi64((const __m128i *)(input + 0 * stride));
+  *in1 = _mm_loadl_epi64((const __m128i *)(input + 1 * stride));
+  *in1 = _mm_unpacklo_epi64(
+      *in1, _mm_loadl_epi64((const __m128i *)(input + 2 * stride)));
+  *in0 = _mm_unpacklo_epi64(
+      *in0, _mm_loadl_epi64((const __m128i *)(input + 3 * stride)));
+  // in0 = [i0 i1 i2 i3 iC iD iE iF]
+  // in1 = [i4 i5 i6 i7 i8 i9 iA iB]
+  // multiply by 16 to give some extra precision
+  *in0 = _mm_slli_epi16(*in0, 4);
+  *in1 = _mm_slli_epi16(*in1, 4);
+  // if (i == 0 && input[0]) input[0] += 1;
+  // add 1 to the upper left pixel if it is non-zero, which helps reduce
+  // the round-trip error
+  {
+    // The mask will only contain whether the first value is zero, all
+    // other comparison will fail as something shifted by 4 (above << 4)
+    // can never be equal to one. To increment in the non-zero case, we
+    // add the mask and one for the first element:
+    //   - if zero, mask = -1, v = v - 1 + 1 = v
+    //   - if non-zero, mask = 0, v = v + 0 + 1 = v + 1
+    __m128i mask = _mm_cmpeq_epi16(*in0, k__nonzero_bias_a);
+    *in0 = _mm_add_epi16(*in0, mask);
+    *in0 = _mm_add_epi16(*in0, k__nonzero_bias_b);
+  }
+  // There are 4 total stages, alternating between an add/subtract stage
+  // followed by an multiply-and-add stage.
+  {
+    // Stage 1: Add/subtract
+
+    // in0 = [i0 i1 i2 i3 iC iD iE iF]
+    // in1 = [i4 i5 i6 i7 i8 i9 iA iB]
+    const __m128i r0 = _mm_unpacklo_epi16(*in0, *in1);
+    const __m128i r1 = _mm_unpackhi_epi16(*in0, *in1);
+    // r0 = [i0 i4 i1 i5 i2 i6 i3 i7]
+    // r1 = [iC i8 iD i9 iE iA iF iB]
+    const __m128i r2 = _mm_shuffle_epi32(r0, 0xB4);
+    const __m128i r3 = _mm_shuffle_epi32(r1, 0xB4);
+    // r2 = [i0 i4 i1 i5 i3 i7 i2 i6]
+    // r3 = [iC i8 iD i9 iF iB iE iA]
+
+    const __m128i t0 = _mm_add_epi16(r2, r3);
+    const __m128i t1 = _mm_sub_epi16(r2, r3);
+    // t0 = [a0 a4 a1 a5 a3 a7 a2 a6]
+    // t1 = [aC a8 aD a9 aF aB aE aA]
+
+    // Stage 2: multiply by constants (which gets us into 32 bits).
+    // The constants needed here are:
+    // k__cospi_A = [p16 p16 p16 p16 p16 m16 p16 m16]
+    // k__cospi_B = [p16 m16 p16 m16 p16 p16 p16 p16]
+    // k__cospi_C = [p08 p24 p08 p24 p24 m08 p24 m08]
+    // k__cospi_D = [p24 m08 p24 m08 p08 p24 p08 p24]
+    const __m128i u0 = _mm_madd_epi16(t0, k__cospi_A);
+    const __m128i u2 = _mm_madd_epi16(t0, k__cospi_B);
+    const __m128i u1 = _mm_madd_epi16(t1, k__cospi_C);
+    const __m128i u3 = _mm_madd_epi16(t1, k__cospi_D);
+    // Then add and right-shift to get back to 16-bit range
+    const __m128i v0 = _mm_add_epi32(u0, k__DCT_CONST_ROUNDING);
+    const __m128i v1 = _mm_add_epi32(u1, k__DCT_CONST_ROUNDING);
+    const __m128i v2 = _mm_add_epi32(u2, k__DCT_CONST_ROUNDING);
+    const __m128i v3 = _mm_add_epi32(u3, k__DCT_CONST_ROUNDING);
+    const __m128i w0 = _mm_srai_epi32(v0, DCT_CONST_BITS);
+    const __m128i w1 = _mm_srai_epi32(v1, DCT_CONST_BITS);
+    const __m128i w2 = _mm_srai_epi32(v2, DCT_CONST_BITS);
+    const __m128i w3 = _mm_srai_epi32(v3, DCT_CONST_BITS);
+    // w0 = [b0 b1 b7 b6]
+    // w1 = [b8 b9 bF bE]
+    // w2 = [b4 b5 b3 b2]
+    // w3 = [bC bD bB bA]
+    const __m128i x0 = _mm_packs_epi32(w0, w1);
+    const __m128i x1 = _mm_packs_epi32(w2, w3);
+
+    // x0 = [b0 b1 b7 b6 b8 b9 bF bE]
+    // x1 = [b4 b5 b3 b2 bC bD bB bA]
+    *in0 = _mm_shuffle_epi32(x0, 0xD8);
+    *in1 = _mm_shuffle_epi32(x1, 0x8D);
+    // in0 = [b0 b1 b8 b9 b7 b6 bF bE]
+    // in1 = [b3 b2 bB bA b4 b5 bC bD]
+  }
+  {
+    // vertical DCTs finished. Now we do the horizontal DCTs.
+    // Stage 3: Add/subtract
+
+    const __m128i t0 = ADD_EPI16(*in0, *in1);
+    const __m128i t1 = SUB_EPI16(*in0, *in1);
+
+    // Stage 4: multiply by constants (which gets us into 32 bits).
+    {
+      // The constants needed here are:
+      // k__cospi_E = [p16 p16 p16 p16 p16 p16 p16 p16]
+      // k__cospi_F = [p16 m16 p16 m16 p16 m16 p16 m16]
+      // k__cospi_G = [p08 p24 p08 p24 m08 m24 m08 m24]
+      // k__cospi_H = [p24 m08 p24 m08 m24 p08 m24 p08]
+      const __m128i u0 = _mm_madd_epi16(t0, k__cospi_E);
+      const __m128i u1 = _mm_madd_epi16(t0, k__cospi_F);
+      const __m128i u2 = _mm_madd_epi16(t1, k__cospi_G);
+      const __m128i u3 = _mm_madd_epi16(t1, k__cospi_H);
+      // Then add and right-shift to get back to 16-bit range
+      // but this combines the final right-shift as well to save operations
+      // This unusual rounding operations is to maintain bit-accurate
+      // compatibility with the c version of this function which has two
+      // rounding steps in a row.
+      const __m128i v0 = _mm_add_epi32(u0, k__DCT_CONST_ROUNDING2);
+      const __m128i v1 = _mm_add_epi32(u1, k__DCT_CONST_ROUNDING2);
+      const __m128i v2 = _mm_add_epi32(u2, k__DCT_CONST_ROUNDING2);
+      const __m128i v3 = _mm_add_epi32(u3, k__DCT_CONST_ROUNDING2);
+      const __m128i w0 = _mm_srai_epi32(v0, DCT_CONST_BITS2);
+      const __m128i w1 = _mm_srai_epi32(v1, DCT_CONST_BITS2);
+      const __m128i w2 = _mm_srai_epi32(v2, DCT_CONST_BITS2);
+      const __m128i w3 = _mm_srai_epi32(v3, DCT_CONST_BITS2);
+      // w0 = [o0 o4 o8 oC]
+      // w1 = [o2 o6 oA oE]
+      // w2 = [o1 o5 o9 oD]
+      // w3 = [o3 o7 oB oF]
+      // remember the o's are numbered according to the correct output location
+      const __m128i x0 = _mm_packs_epi32(w0, w1);
+      const __m128i x1 = _mm_packs_epi32(w2, w3);
+      {
+        // x0 = [o0 o4 o8 oC o2 o6 oA oE]
+        // x1 = [o1 o5 o9 oD o3 o7 oB oF]
+        const __m128i y0 = _mm_unpacklo_epi16(x0, x1);
+        const __m128i y1 = _mm_unpackhi_epi16(x0, x1);
+        // y0 = [o0 o1 o4 o5 o8 o9 oC oD]
+        // y1 = [o2 o3 o6 o7 oA oB oE oF]
+        *in0 = _mm_unpacklo_epi32(y0, y1);
+        // in0 = [o0 o1 o2 o3 o4 o5 o6 o7]
+        *in1 = _mm_unpackhi_epi32(y0, y1);
+        // in1 = [o8 o9 oA oB oC oD oE oF]
+      }
+    }
+  }
+}
+
+void FDCT4x4_2D(const int16_t *input, tran_low_t *output, int stride) {
+  // This 2D transform implements 4 vertical 1D transforms followed
+  // by 4 horizontal 1D transforms.  The multiplies and adds are as given
+  // by Chen, Smith and Fralick ('77).  The commands for moving the data
+  // around have been minimized by hand.
+  // For the purposes of the comments, the 16 inputs are referred to at i0
+  // through iF (in raster order), intermediate variables are a0, b0, c0
+  // through f, and correspond to the in-place computations mapped to input
+  // locations.  The outputs, o0 through oF are labeled according to the
+  // output locations.
+  __m128i in0, in1;
+  FDCT4x4_2D_HELPER(input, stride, &in0, &in1);
+
+  // Post-condition (v + 1) >> 2 is now incorporated into previous
+  // add and right-shift commands.  Only 2 store instructions needed
+  // because we are using the fact that 1/3 are stored just after 0/2.
+  storeu_output(&in0, output + 0 * 4);
+  storeu_output(&in1, output + 2 * 4);
+}
+
+void FDCT4x4_2D_LP(const int16_t *input, int16_t *output, int stride) {
+  __m128i in0, in1;
+  FDCT4x4_2D_HELPER(input, stride, &in0, &in1);
+  _mm_storeu_si128((__m128i *)(output + 0 * 4), in0);
+  _mm_storeu_si128((__m128i *)(output + 2 * 4), in1);
+}
+
 void FDCT8x8_2D(const int16_t *input, tran_low_t *output, int stride) {
   int pass;
   // Constants
diff --git a/aom_dsp/x86/fwd_txfm_sse2.c b/aom_dsp/x86/fwd_txfm_sse2.c
index 11c7d88..0e4fb80 100644
--- a/aom_dsp/x86/fwd_txfm_sse2.c
+++ b/aom_dsp/x86/fwd_txfm_sse2.c
@@ -18,8 +18,14 @@
 #include "aom_dsp/x86/fwd_txfm_sse2.h"
 
 #define DCT_HIGH_BIT_DEPTH 0
+#define FDCT4x4_2D_HELPER fdct4x4_helper
+#define FDCT4x4_2D aom_fdct4x4_sse2
+#define FDCT4x4_2D_LP aom_fdct4x4_lp_sse2
 #define FDCT8x8_2D aom_fdct8x8_sse2
 #include "aom_dsp/x86/fwd_txfm_impl_sse2.h"
+#undef FDCT4x4_2D_HELPER
+#undef FDCT4x4_2D
+#undef FDCT4x4_2D_LP
 #undef FDCT8x8_2D
 
 #if CONFIG_AV1_HIGHBITDEPTH
diff --git a/aom_dsp/x86/fwd_txfm_sse2.h b/aom_dsp/x86/fwd_txfm_sse2.h
index 260d8dd..ab3cd91 100644
--- a/aom_dsp/x86/fwd_txfm_sse2.h
+++ b/aom_dsp/x86/fwd_txfm_sse2.h
@@ -136,16 +136,21 @@
 }
 
 static INLINE void store_output(const __m128i *poutput, tran_low_t *dst_ptr) {
-  if (sizeof(tran_low_t) == 4) {
-    const __m128i zero = _mm_setzero_si128();
-    const __m128i sign_bits = _mm_cmplt_epi16(*poutput, zero);
-    __m128i out0 = _mm_unpacklo_epi16(*poutput, sign_bits);
-    __m128i out1 = _mm_unpackhi_epi16(*poutput, sign_bits);
-    _mm_store_si128((__m128i *)(dst_ptr), out0);
-    _mm_store_si128((__m128i *)(dst_ptr + 4), out1);
-  } else {
-    _mm_store_si128((__m128i *)(dst_ptr), *poutput);
-  }
+  const __m128i zero = _mm_setzero_si128();
+  const __m128i sign_bits = _mm_cmplt_epi16(*poutput, zero);
+  __m128i out0 = _mm_unpacklo_epi16(*poutput, sign_bits);
+  __m128i out1 = _mm_unpackhi_epi16(*poutput, sign_bits);
+  _mm_store_si128((__m128i *)(dst_ptr), out0);
+  _mm_store_si128((__m128i *)(dst_ptr + 4), out1);
+}
+
+static INLINE void storeu_output(const __m128i *poutput, tran_low_t *dst_ptr) {
+  const __m128i zero = _mm_setzero_si128();
+  const __m128i sign_bits = _mm_cmplt_epi16(*poutput, zero);
+  __m128i out0 = _mm_unpacklo_epi16(*poutput, sign_bits);
+  __m128i out1 = _mm_unpackhi_epi16(*poutput, sign_bits);
+  _mm_storeu_si128((__m128i *)(dst_ptr), out0);
+  _mm_storeu_si128((__m128i *)(dst_ptr + 4), out1);
 }
 
 #ifdef __cplusplus
diff --git a/aom_dsp/x86/txfm_common_sse2.h b/aom_dsp/x86/txfm_common_sse2.h
index ed82eee..9c99eb9 100644
--- a/aom_dsp/x86/txfm_common_sse2.h
+++ b/aom_dsp/x86/txfm_common_sse2.h
@@ -26,4 +26,8 @@
   return _mm_shuffle_epi32(b, 0x4e);
 }
 
+#define octa_set_epi16(a, b, c, d, e, f, g, h)                           \
+  _mm_setr_epi16((int16_t)(a), (int16_t)(b), (int16_t)(c), (int16_t)(d), \
+                 (int16_t)(e), (int16_t)(f), (int16_t)(g), (int16_t)(h))
+
 #endif  // AOM_AOM_DSP_X86_TXFM_COMMON_SSE2_H_
diff --git a/test/fdct4x4_test.cc b/test/fdct4x4_test.cc
new file mode 100644
index 0000000..9679777
--- /dev/null
+++ b/test/fdct4x4_test.cc
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) 2020, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "aom_dsp/aom_dsp_common.h"
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+
+#include "config/av1_rtcd.h"
+#include "config/aom_dsp_rtcd.h"
+#include "test/acm_random.h"
+#include "test/clear_system_state.h"
+#include "test/register_state_check.h"
+#include "test/transform_test_base.h"
+#include "test/util.h"
+#include "av1/common/entropy.h"
+#include "aom/aom_codec.h"
+#include "aom/aom_integer.h"
+#include "aom_ports/mem.h"
+
+using libaom_test::ACMRandom;
+
+namespace {
+
+template <typename OutputType>
+using FdctFunc = void (*)(const int16_t *in, OutputType *out, int stride);
+
+template <typename OutputType>
+using FhtFunc = void (*)(const int16_t *in, OutputType *out, int stride,
+                         TxfmParam *txfm_param);
+
+template <typename OutputType>
+using Fdct4x4Param = ::testing::tuple<FdctFunc<OutputType>, FhtFunc<OutputType>,
+                                      aom_bit_depth_t, int>;
+
+#if HAVE_NEON || HAVE_SSE2
+void fdct4x4_ref(const int16_t *in, tran_low_t *out, int stride,
+                 TxfmParam * /*txfm_param*/) {
+  aom_fdct4x4_c(in, out, stride);
+}
+
+void fdct4x4_lp_ref(const int16_t *in, int16_t *out, int stride,
+                    TxfmParam * /*txfm_param*/) {
+  aom_fdct4x4_lp_c(in, out, stride);
+}
+#endif
+
+template <typename OutputType>
+class Trans4x4FDCT : public libaom_test::TransformTestBase<OutputType>,
+                     public ::testing::TestWithParam<Fdct4x4Param<OutputType>> {
+ public:
+  virtual ~Trans4x4FDCT() {}
+
+  using TxfmBaseOutType = libaom_test::TransformTestBase<OutputType>;
+  virtual void SetUp() {
+    fwd_txfm_ = ::testing::get<0>(this->GetParam());
+    TxfmBaseOutType::pitch_ = 4;
+    TxfmBaseOutType::height_ = 4;
+    TxfmBaseOutType::fwd_txfm_ref = ::testing::get<1>(this->GetParam());
+    TxfmBaseOutType::bit_depth_ = ::testing::get<2>(this->GetParam());
+    TxfmBaseOutType::mask_ = (1 << TxfmBaseOutType::bit_depth_) - 1;
+    TxfmBaseOutType::num_coeffs_ = ::testing::get<3>(this->GetParam());
+  }
+  virtual void TearDown() { libaom_test::ClearSystemState(); }
+
+ protected:
+  void RunFwdTxfm(const int16_t *in, OutputType *out, int stride) {
+    fwd_txfm_(in, out, stride);
+  }
+
+  void RunInvTxfm(const OutputType *out, uint8_t *dst, int stride) {
+    (void)out;
+    (void)dst;
+    (void)stride;
+  }
+
+  FdctFunc<OutputType> fwd_txfm_;
+};
+
+using Trans4x4FDCTTranLow = Trans4x4FDCT<tran_low_t>;
+TEST_P(Trans4x4FDCTTranLow, CoeffCheck) { RunCoeffCheck(); }
+TEST_P(Trans4x4FDCTTranLow, MemCheck) { RunMemCheck(); }
+
+using Trans4x4FDCTInt16 = Trans4x4FDCT<int16_t>;
+TEST_P(Trans4x4FDCTInt16, CoeffCheck) { RunCoeffCheck(); }
+TEST_P(Trans4x4FDCTInt16, MemCheck) { RunMemCheck(); }
+
+using ::testing::make_tuple;
+
+#if HAVE_NEON
+INSTANTIATE_TEST_CASE_P(NEON, Trans4x4FDCTTranLow,
+                        ::testing::Values(make_tuple(&aom_fdct4x4_neon,
+                                                     &fdct4x4_ref, AOM_BITS_8,
+                                                     16)));
+
+INSTANTIATE_TEST_CASE_P(NEON, Trans4x4FDCTInt16,
+                        ::testing::Values(make_tuple(&aom_fdct4x4_lp_neon,
+                                                     &fdct4x4_lp_ref,
+                                                     AOM_BITS_8, 16)));
+#endif
+
+#if HAVE_SSE2
+INSTANTIATE_TEST_CASE_P(SSE2, Trans4x4FDCTTranLow,
+                        ::testing::Values(make_tuple(&aom_fdct4x4_sse2,
+                                                     &fdct4x4_ref, AOM_BITS_8,
+                                                     16)));
+
+INSTANTIATE_TEST_CASE_P(SSE2, Trans4x4FDCTInt16,
+                        ::testing::Values(make_tuple(&aom_fdct4x4_lp_sse2,
+                                                     &fdct4x4_lp_ref,
+                                                     AOM_BITS_8, 16)));
+#endif
+}  // namespace
diff --git a/test/fwht4x4_test.cc b/test/fwht4x4_test.cc
index c8d98c5..251acb7 100644
--- a/test/fwht4x4_test.cc
+++ b/test/fwht4x4_test.cc
@@ -13,6 +13,7 @@
 #include <stdlib.h>
 #include <string.h>
 
+#include "aom_dsp/aom_dsp_common.h"
 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
 
 #include "config/av1_rtcd.h"
@@ -51,7 +52,7 @@
   av1_highbd_iwht4x4_16_add_c(in, out, stride, 12);
 }
 
-class Trans4x4WHT : public libaom_test::TransformTestBase,
+class Trans4x4WHT : public libaom_test::TransformTestBase<tran_low_t>,
                     public ::testing::TestWithParam<Dct4x4Param> {
  public:
   virtual ~Trans4x4WHT() {}
diff --git a/test/test.cmake b/test/test.cmake
index 5ae2ab6..b172e64 100644
--- a/test/test.cmake
+++ b/test/test.cmake
@@ -207,6 +207,7 @@
               "${AOM_ROOT}/test/error_block_test.cc"
               "${AOM_ROOT}/test/fft_test.cc"
               "${AOM_ROOT}/test/fwht4x4_test.cc"
+              "${AOM_ROOT}/test/fdct4x4_test.cc"
               "${AOM_ROOT}/test/hadamard_test.cc"
               "${AOM_ROOT}/test/horver_correlation_test.cc"
               "${AOM_ROOT}/test/masked_sad_test.cc"
diff --git a/test/transform_test_base.h b/test/transform_test_base.h
index 8ebcf5f..68f5cc7 100644
--- a/test/transform_test_base.h
+++ b/test/transform_test_base.h
@@ -29,20 +29,23 @@
 //   to a aom header file.
 const int kDctMaxValue = 16384;
 
-typedef void (*FhtFunc)(const int16_t *in, tran_low_t *out, int stride,
-                        TxfmParam *txfm_param);
+template <typename OutputType>
+using FhtFunc = void (*)(const int16_t *in, OutputType *out, int stride,
+                         TxfmParam *txfm_param);
 
-typedef void (*IhtFunc)(const tran_low_t *in, uint8_t *out, int stride,
-                        const TxfmParam *txfm_param);
+template <typename OutputType>
+using IhtFunc = void (*)(const tran_low_t *in, uint8_t *out, int stride,
+                         const TxfmParam *txfm_param);
 
+template <typename OutType>
 class TransformTestBase {
  public:
   virtual ~TransformTestBase() {}
 
  protected:
-  virtual void RunFwdTxfm(const int16_t *in, tran_low_t *out, int stride) = 0;
+  virtual void RunFwdTxfm(const int16_t *in, OutType *out, int stride) = 0;
 
-  virtual void RunInvTxfm(const tran_low_t *out, uint8_t *dst, int stride) = 0;
+  virtual void RunInvTxfm(const OutType *out, uint8_t *dst, int stride) = 0;
 
   void RunAccuracyCheck(uint32_t ref_max_error, double ref_avg_error) {
     ACMRandom rnd(ACMRandom::DeterministicSeed());
@@ -52,8 +55,8 @@
 
     int16_t *test_input_block = reinterpret_cast<int16_t *>(
         aom_memalign(16, sizeof(int16_t) * num_coeffs_));
-    tran_low_t *test_temp_block = reinterpret_cast<tran_low_t *>(
-        aom_memalign(16, sizeof(tran_low_t) * num_coeffs_));
+    OutType *test_temp_block = reinterpret_cast<OutType *>(
+        aom_memalign(16, sizeof(test_temp_block[0]) * num_coeffs_));
     uint8_t *dst = reinterpret_cast<uint8_t *>(
         aom_memalign(16, sizeof(uint8_t) * num_coeffs_));
     uint8_t *src = reinterpret_cast<uint8_t *>(
@@ -123,10 +126,10 @@
 
     int16_t *input_block = reinterpret_cast<int16_t *>(
         aom_memalign(16, sizeof(int16_t) * stride * height_));
-    tran_low_t *output_ref_block = reinterpret_cast<tran_low_t *>(
-        aom_memalign(16, sizeof(tran_low_t) * num_coeffs_));
-    tran_low_t *output_block = reinterpret_cast<tran_low_t *>(
-        aom_memalign(16, sizeof(tran_low_t) * num_coeffs_));
+    OutType *output_ref_block = reinterpret_cast<OutType *>(
+        aom_memalign(16, sizeof(output_ref_block[0]) * num_coeffs_));
+    OutType *output_block = reinterpret_cast<OutType *>(
+        aom_memalign(16, sizeof(output_block[0]) * num_coeffs_));
 
     for (int i = 0; i < count_test_block; ++i) {
       int j, k;
@@ -172,8 +175,8 @@
 
     int16_t *input_block = reinterpret_cast<int16_t *>(
         aom_memalign(16, sizeof(int16_t) * num_coeffs_));
-    tran_low_t *trans_block = reinterpret_cast<tran_low_t *>(
-        aom_memalign(16, sizeof(tran_low_t) * num_coeffs_));
+    OutType *trans_block = reinterpret_cast<OutType *>(
+        aom_memalign(16, sizeof(trans_block[0]) * num_coeffs_));
     uint8_t *output_block = reinterpret_cast<uint8_t *>(
         aom_memalign(16, sizeof(uint8_t) * stride * height_));
     uint8_t *output_ref_block = reinterpret_cast<uint8_t *>(
@@ -218,10 +221,10 @@
 
     int16_t *input_extreme_block = reinterpret_cast<int16_t *>(
         aom_memalign(16, sizeof(int16_t) * num_coeffs_));
-    tran_low_t *output_ref_block = reinterpret_cast<tran_low_t *>(
-        aom_memalign(16, sizeof(tran_low_t) * num_coeffs_));
-    tran_low_t *output_block = reinterpret_cast<tran_low_t *>(
-        aom_memalign(16, sizeof(tran_low_t) * num_coeffs_));
+    OutType *output_ref_block = reinterpret_cast<OutType *>(
+        aom_memalign(16, sizeof(output_ref_block[0]) * num_coeffs_));
+    OutType *output_block = reinterpret_cast<OutType *>(
+        aom_memalign(16, sizeof(output_block[0]) * num_coeffs_));
 
     for (int i = 0; i < count_test_block; ++i) {
       // Initialize a test block with input range [-mask_, mask_].
@@ -260,8 +263,8 @@
 
     int16_t *in = reinterpret_cast<int16_t *>(
         aom_memalign(16, sizeof(int16_t) * num_coeffs_));
-    tran_low_t *coeff = reinterpret_cast<tran_low_t *>(
-        aom_memalign(16, sizeof(tran_low_t) * num_coeffs_));
+    OutType *coeff = reinterpret_cast<OutType *>(
+        aom_memalign(16, sizeof(coeff[0]) * num_coeffs_));
     uint8_t *dst = reinterpret_cast<uint8_t *>(
         aom_memalign(16, sizeof(uint8_t) * num_coeffs_));
     uint8_t *src = reinterpret_cast<uint8_t *>(
@@ -313,8 +316,8 @@
 
   int pitch_;
   int height_;
-  FhtFunc fwd_txfm_ref;
-  IhtFunc inv_txfm_ref;
+  FhtFunc<OutType> fwd_txfm_ref;
+  IhtFunc<OutType> inv_txfm_ref;
   aom_bit_depth_t bit_depth_;
   int mask_;
   int num_coeffs_;