Add neon/sse2 optimization for fwd txfm 4x4. Add unit test to verify that the C/SSE2/NEON implementations are bit-exact. Change-Id: I24bdee1387622510436ed149e7b61c937488a735
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl index c84a380..b2b4b19 100755 --- a/aom_dsp/aom_dsp_rtcd_defs.pl +++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -481,8 +481,10 @@ # if (aom_config("CONFIG_AV1_ENCODER") eq "yes"){ add_proto qw/void aom_fdct4x4/, "const int16_t *input, tran_low_t *output, int stride"; + specialize qw/aom_fdct4x4 neon sse2/; add_proto qw/void aom_fdct4x4_lp/, "const int16_t *input, int16_t *output, int stride"; + specialize qw/aom_fdct4x4_lp neon sse2/; add_proto qw/void aom_fdct8x8/, "const int16_t *input, tran_low_t *output, int stride"; specialize qw/aom_fdct8x8 neon sse2/, "$ssse3_x86_64";
diff --git a/aom_dsp/arm/fwd_txfm_neon.c b/aom_dsp/arm/fwd_txfm_neon.c index e4300c9..ce93523 100644 --- a/aom_dsp/arm/fwd_txfm_neon.c +++ b/aom_dsp/arm/fwd_txfm_neon.c
@@ -14,9 +14,103 @@ #include "config/aom_config.h" #include "aom_dsp/txfm_common.h" +#include "av1/common/arm/mem_neon.h" +#include "av1/common/arm/transpose_neon.h" + +static void aom_fdct4x4_helper(const int16_t *input, int stride, + int16x4_t *input_0, int16x4_t *input_1, + int16x4_t *input_2, int16x4_t *input_3) { + *input_0 = vshl_n_s16(vld1_s16(input + 0 * stride), 4); + *input_1 = vshl_n_s16(vld1_s16(input + 1 * stride), 4); + *input_2 = vshl_n_s16(vld1_s16(input + 2 * stride), 4); + *input_3 = vshl_n_s16(vld1_s16(input + 3 * stride), 4); + // If the very first value != 0, then add 1. + if (input[0] != 0) { + const int16x4_t one = vreinterpret_s16_s64(vdup_n_s64(1)); + *input_0 = vadd_s16(*input_0, one); + } + + for (int i = 0; i < 2; ++i) { + const int16x8_t input_01 = vcombine_s16(*input_0, *input_1); + const int16x8_t input_32 = vcombine_s16(*input_3, *input_2); + + // in_0 +/- in_3, in_1 +/- in_2 + const int16x8_t s_01 = vaddq_s16(input_01, input_32); + const int16x8_t s_32 = vsubq_s16(input_01, input_32); + + // step_0 +/- step_1, step_2 +/- step_3 + const int16x4_t s_0 = vget_low_s16(s_01); + const int16x4_t s_1 = vget_high_s16(s_01); + const int16x4_t s_2 = vget_high_s16(s_32); + const int16x4_t s_3 = vget_low_s16(s_32); + + // (s_0 +/- s_1) * cospi_16_64 + // Must expand all elements to s32. See 'needs32' comment in fwd_txfm.c. 
+ const int32x4_t s_0_p_s_1 = vaddl_s16(s_0, s_1); + const int32x4_t s_0_m_s_1 = vsubl_s16(s_0, s_1); + const int32x4_t temp1 = vmulq_n_s32(s_0_p_s_1, cospi_16_64); + const int32x4_t temp2 = vmulq_n_s32(s_0_m_s_1, cospi_16_64); + + // fdct_round_shift + int16x4_t out_0 = vrshrn_n_s32(temp1, DCT_CONST_BITS); + int16x4_t out_2 = vrshrn_n_s32(temp2, DCT_CONST_BITS); + + // s_3 * cospi_8_64 + s_2 * cospi_24_64 + // s_3 * cospi_24_64 - s_2 * cospi_8_64 + const int32x4_t s_3_cospi_8_64 = vmull_n_s16(s_3, cospi_8_64); + const int32x4_t s_3_cospi_24_64 = vmull_n_s16(s_3, cospi_24_64); + + const int32x4_t temp3 = vmlal_n_s16(s_3_cospi_8_64, s_2, cospi_24_64); + const int32x4_t temp4 = vmlsl_n_s16(s_3_cospi_24_64, s_2, cospi_8_64); + + // fdct_round_shift + int16x4_t out_1 = vrshrn_n_s32(temp3, DCT_CONST_BITS); + int16x4_t out_3 = vrshrn_n_s32(temp4, DCT_CONST_BITS); + + transpose_s16_4x4d(&out_0, &out_1, &out_2, &out_3); + + *input_0 = out_0; + *input_1 = out_1; + *input_2 = out_2; + *input_3 = out_3; + } +} + +void aom_fdct4x4_neon(const int16_t *input, tran_low_t *final_output, + int stride) { + // input[M * stride] * 16 + int16x4_t input_0, input_1, input_2, input_3; + + aom_fdct4x4_helper(input, stride, &input_0, &input_1, &input_2, &input_3); + + // Not quite a rounding shift. Only add 1 despite shifting by 2. + const int16x8_t one = vdupq_n_s16(1); + int16x8_t out_01 = vcombine_s16(input_0, input_1); + int16x8_t out_23 = vcombine_s16(input_2, input_3); + out_01 = vshrq_n_s16(vaddq_s16(out_01, one), 2); + out_23 = vshrq_n_s16(vaddq_s16(out_23, one), 2); + store_s16q_to_tran_low(final_output + 0 * 8, out_01); + store_s16q_to_tran_low(final_output + 1 * 8, out_23); +} + +void aom_fdct4x4_lp_neon(const int16_t *input, int16_t *final_output, + int stride) { + // input[M * stride] * 16 + int16x4_t input_0, input_1, input_2, input_3; + + aom_fdct4x4_helper(input, stride, &input_0, &input_1, &input_2, &input_3); + + // Not quite a rounding shift. 
Only add 1 despite shifting by 2. + const int16x8_t one = vdupq_n_s16(1); + int16x8_t out_01 = vcombine_s16(input_0, input_1); + int16x8_t out_23 = vcombine_s16(input_2, input_3); + out_01 = vshrq_n_s16(vaddq_s16(out_01, one), 2); + out_23 = vshrq_n_s16(vaddq_s16(out_23, one), 2); + vst1q_s16(final_output + 0 * 8, out_01); + vst1q_s16(final_output + 1 * 8, out_23); +} void aom_fdct8x8_neon(const int16_t *input, int16_t *final_output, int stride) { - int i; // stage 1 int16x8_t input_0 = vshlq_n_s16(vld1q_s16(&input[0 * stride]), 2); int16x8_t input_1 = vshlq_n_s16(vld1q_s16(&input[1 * stride]), 2); @@ -26,7 +120,7 @@ int16x8_t input_5 = vshlq_n_s16(vld1q_s16(&input[5 * stride]), 2); int16x8_t input_6 = vshlq_n_s16(vld1q_s16(&input[6 * stride]), 2); int16x8_t input_7 = vshlq_n_s16(vld1q_s16(&input[7 * stride]), 2); - for (i = 0; i < 2; ++i) { + for (int i = 0; i < 2; ++i) { int16x8_t out_0, out_1, out_2, out_3, out_4, out_5, out_6, out_7; const int16x8_t v_s0 = vaddq_s16(input_0, input_7); const int16x8_t v_s1 = vaddq_s16(input_1, input_6);
diff --git a/aom_dsp/x86/fwd_txfm_impl_sse2.h b/aom_dsp/x86/fwd_txfm_impl_sse2.h index 1e3d13e..89fe189 100644 --- a/aom_dsp/x86/fwd_txfm_impl_sse2.h +++ b/aom_dsp/x86/fwd_txfm_impl_sse2.h
@@ -30,6 +30,206 @@ #define SUB_EPI16 _mm_sub_epi16 #endif +static void FDCT4x4_2D_HELPER(const int16_t *input, int stride, __m128i *in0, + __m128i *in1) { + // Constants + // These are the coefficients used for the multiplies. + // In the comments, pN means cos(N pi /64) and mN is -cos(N pi /64), + // where cospi_N_64 = cos(N pi /64) + const __m128i k__cospi_A = + octa_set_epi16(cospi_16_64, cospi_16_64, cospi_16_64, cospi_16_64, + cospi_16_64, -cospi_16_64, cospi_16_64, -cospi_16_64); + const __m128i k__cospi_B = + octa_set_epi16(cospi_16_64, -cospi_16_64, cospi_16_64, -cospi_16_64, + cospi_16_64, cospi_16_64, cospi_16_64, cospi_16_64); + const __m128i k__cospi_C = + octa_set_epi16(cospi_8_64, cospi_24_64, cospi_8_64, cospi_24_64, + cospi_24_64, -cospi_8_64, cospi_24_64, -cospi_8_64); + const __m128i k__cospi_D = + octa_set_epi16(cospi_24_64, -cospi_8_64, cospi_24_64, -cospi_8_64, + cospi_8_64, cospi_24_64, cospi_8_64, cospi_24_64); + const __m128i k__cospi_E = + octa_set_epi16(cospi_16_64, cospi_16_64, cospi_16_64, cospi_16_64, + cospi_16_64, cospi_16_64, cospi_16_64, cospi_16_64); + const __m128i k__cospi_F = + octa_set_epi16(cospi_16_64, -cospi_16_64, cospi_16_64, -cospi_16_64, + cospi_16_64, -cospi_16_64, cospi_16_64, -cospi_16_64); + const __m128i k__cospi_G = + octa_set_epi16(cospi_8_64, cospi_24_64, cospi_8_64, cospi_24_64, + -cospi_8_64, -cospi_24_64, -cospi_8_64, -cospi_24_64); + const __m128i k__cospi_H = + octa_set_epi16(cospi_24_64, -cospi_8_64, cospi_24_64, -cospi_8_64, + -cospi_24_64, cospi_8_64, -cospi_24_64, cospi_8_64); + + const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING); + // This second rounding constant saves doing some extra adds at the end + const __m128i k__DCT_CONST_ROUNDING2 = + _mm_set1_epi32(DCT_CONST_ROUNDING + (DCT_CONST_ROUNDING << 1)); + const int DCT_CONST_BITS2 = DCT_CONST_BITS + 2; + const __m128i k__nonzero_bias_a = _mm_setr_epi16(0, 1, 1, 1, 1, 1, 1, 1); + const __m128i k__nonzero_bias_b = 
_mm_setr_epi16(1, 0, 0, 0, 0, 0, 0, 0); + + // Load inputs. + *in0 = _mm_loadl_epi64((const __m128i *)(input + 0 * stride)); + *in1 = _mm_loadl_epi64((const __m128i *)(input + 1 * stride)); + *in1 = _mm_unpacklo_epi64( + *in1, _mm_loadl_epi64((const __m128i *)(input + 2 * stride))); + *in0 = _mm_unpacklo_epi64( + *in0, _mm_loadl_epi64((const __m128i *)(input + 3 * stride))); + // in0 = [i0 i1 i2 i3 iC iD iE iF] + // in1 = [i4 i5 i6 i7 i8 i9 iA iB] + // multiply by 16 to give some extra precision + *in0 = _mm_slli_epi16(*in0, 4); + *in1 = _mm_slli_epi16(*in1, 4); + // if (i == 0 && input[0]) input[0] += 1; + // add 1 to the upper left pixel if it is non-zero, which helps reduce + // the round-trip error + { + // The mask will only contain whether the first value is zero, all + // other comparison will fail as something shifted by 4 (above << 4) + // can never be equal to one. To increment in the non-zero case, we + // add the mask and one for the first element: + // - if zero, mask = -1, v = v - 1 + 1 = v + // - if non-zero, mask = 0, v = v + 0 + 1 = v + 1 + __m128i mask = _mm_cmpeq_epi16(*in0, k__nonzero_bias_a); + *in0 = _mm_add_epi16(*in0, mask); + *in0 = _mm_add_epi16(*in0, k__nonzero_bias_b); + } + // There are 4 total stages, alternating between an add/subtract stage + // followed by an multiply-and-add stage. 
+ { + // Stage 1: Add/subtract + + // in0 = [i0 i1 i2 i3 iC iD iE iF] + // in1 = [i4 i5 i6 i7 i8 i9 iA iB] + const __m128i r0 = _mm_unpacklo_epi16(*in0, *in1); + const __m128i r1 = _mm_unpackhi_epi16(*in0, *in1); + // r0 = [i0 i4 i1 i5 i2 i6 i3 i7] + // r1 = [iC i8 iD i9 iE iA iF iB] + const __m128i r2 = _mm_shuffle_epi32(r0, 0xB4); + const __m128i r3 = _mm_shuffle_epi32(r1, 0xB4); + // r2 = [i0 i4 i1 i5 i3 i7 i2 i6] + // r3 = [iC i8 iD i9 iF iB iE iA] + + const __m128i t0 = _mm_add_epi16(r2, r3); + const __m128i t1 = _mm_sub_epi16(r2, r3); + // t0 = [a0 a4 a1 a5 a3 a7 a2 a6] + // t1 = [aC a8 aD a9 aF aB aE aA] + + // Stage 2: multiply by constants (which gets us into 32 bits). + // The constants needed here are: + // k__cospi_A = [p16 p16 p16 p16 p16 m16 p16 m16] + // k__cospi_B = [p16 m16 p16 m16 p16 p16 p16 p16] + // k__cospi_C = [p08 p24 p08 p24 p24 m08 p24 m08] + // k__cospi_D = [p24 m08 p24 m08 p08 p24 p08 p24] + const __m128i u0 = _mm_madd_epi16(t0, k__cospi_A); + const __m128i u2 = _mm_madd_epi16(t0, k__cospi_B); + const __m128i u1 = _mm_madd_epi16(t1, k__cospi_C); + const __m128i u3 = _mm_madd_epi16(t1, k__cospi_D); + // Then add and right-shift to get back to 16-bit range + const __m128i v0 = _mm_add_epi32(u0, k__DCT_CONST_ROUNDING); + const __m128i v1 = _mm_add_epi32(u1, k__DCT_CONST_ROUNDING); + const __m128i v2 = _mm_add_epi32(u2, k__DCT_CONST_ROUNDING); + const __m128i v3 = _mm_add_epi32(u3, k__DCT_CONST_ROUNDING); + const __m128i w0 = _mm_srai_epi32(v0, DCT_CONST_BITS); + const __m128i w1 = _mm_srai_epi32(v1, DCT_CONST_BITS); + const __m128i w2 = _mm_srai_epi32(v2, DCT_CONST_BITS); + const __m128i w3 = _mm_srai_epi32(v3, DCT_CONST_BITS); + // w0 = [b0 b1 b7 b6] + // w1 = [b8 b9 bF bE] + // w2 = [b4 b5 b3 b2] + // w3 = [bC bD bB bA] + const __m128i x0 = _mm_packs_epi32(w0, w1); + const __m128i x1 = _mm_packs_epi32(w2, w3); + + // x0 = [b0 b1 b7 b6 b8 b9 bF bE] + // x1 = [b4 b5 b3 b2 bC bD bB bA] + *in0 = _mm_shuffle_epi32(x0, 0xD8); + *in1 = 
_mm_shuffle_epi32(x1, 0x8D); + // in0 = [b0 b1 b8 b9 b7 b6 bF bE] + // in1 = [b3 b2 bB bA b4 b5 bC bD] + } + { + // vertical DCTs finished. Now we do the horizontal DCTs. + // Stage 3: Add/subtract + + const __m128i t0 = ADD_EPI16(*in0, *in1); + const __m128i t1 = SUB_EPI16(*in0, *in1); + + // Stage 4: multiply by constants (which gets us into 32 bits). + { + // The constants needed here are: + // k__cospi_E = [p16 p16 p16 p16 p16 p16 p16 p16] + // k__cospi_F = [p16 m16 p16 m16 p16 m16 p16 m16] + // k__cospi_G = [p08 p24 p08 p24 m08 m24 m08 m24] + // k__cospi_H = [p24 m08 p24 m08 m24 p08 m24 p08] + const __m128i u0 = _mm_madd_epi16(t0, k__cospi_E); + const __m128i u1 = _mm_madd_epi16(t0, k__cospi_F); + const __m128i u2 = _mm_madd_epi16(t1, k__cospi_G); + const __m128i u3 = _mm_madd_epi16(t1, k__cospi_H); + // Then add and right-shift to get back to 16-bit range + // but this combines the final right-shift as well to save operations + // This unusual rounding operations is to maintain bit-accurate + // compatibility with the c version of this function which has two + // rounding steps in a row. 
+ const __m128i v0 = _mm_add_epi32(u0, k__DCT_CONST_ROUNDING2); + const __m128i v1 = _mm_add_epi32(u1, k__DCT_CONST_ROUNDING2); + const __m128i v2 = _mm_add_epi32(u2, k__DCT_CONST_ROUNDING2); + const __m128i v3 = _mm_add_epi32(u3, k__DCT_CONST_ROUNDING2); + const __m128i w0 = _mm_srai_epi32(v0, DCT_CONST_BITS2); + const __m128i w1 = _mm_srai_epi32(v1, DCT_CONST_BITS2); + const __m128i w2 = _mm_srai_epi32(v2, DCT_CONST_BITS2); + const __m128i w3 = _mm_srai_epi32(v3, DCT_CONST_BITS2); + // w0 = [o0 o4 o8 oC] + // w1 = [o2 o6 oA oE] + // w2 = [o1 o5 o9 oD] + // w3 = [o3 o7 oB oF] + // remember the o's are numbered according to the correct output location + const __m128i x0 = _mm_packs_epi32(w0, w1); + const __m128i x1 = _mm_packs_epi32(w2, w3); + { + // x0 = [o0 o4 o8 oC o2 o6 oA oE] + // x1 = [o1 o5 o9 oD o3 o7 oB oF] + const __m128i y0 = _mm_unpacklo_epi16(x0, x1); + const __m128i y1 = _mm_unpackhi_epi16(x0, x1); + // y0 = [o0 o1 o4 o5 o8 o9 oC oD] + // y1 = [o2 o3 o6 o7 oA oB oE oF] + *in0 = _mm_unpacklo_epi32(y0, y1); + // in0 = [o0 o1 o2 o3 o4 o5 o6 o7] + *in1 = _mm_unpackhi_epi32(y0, y1); + // in1 = [o8 o9 oA oB oC oD oE oF] + } + } + } +} + +void FDCT4x4_2D(const int16_t *input, tran_low_t *output, int stride) { + // This 2D transform implements 4 vertical 1D transforms followed + // by 4 horizontal 1D transforms. The multiplies and adds are as given + // by Chen, Smith and Fralick ('77). The commands for moving the data + // around have been minimized by hand. + // For the purposes of the comments, the 16 inputs are referred to at i0 + // through iF (in raster order), intermediate variables are a0, b0, c0 + // through f, and correspond to the in-place computations mapped to input + // locations. The outputs, o0 through oF are labeled according to the + // output locations. + __m128i in0, in1; + FDCT4x4_2D_HELPER(input, stride, &in0, &in1); + + // Post-condition (v + 1) >> 2 is now incorporated into previous + // add and right-shift commands. 
Only 2 store instructions needed + // because we are using the fact that 1/3 are stored just after 0/2. + storeu_output(&in0, output + 0 * 4); + storeu_output(&in1, output + 2 * 4); +} + +void FDCT4x4_2D_LP(const int16_t *input, int16_t *output, int stride) { + __m128i in0, in1; + FDCT4x4_2D_HELPER(input, stride, &in0, &in1); + _mm_storeu_si128((__m128i *)(output + 0 * 4), in0); + _mm_storeu_si128((__m128i *)(output + 2 * 4), in1); +} + void FDCT8x8_2D(const int16_t *input, tran_low_t *output, int stride) { int pass; // Constants
diff --git a/aom_dsp/x86/fwd_txfm_sse2.c b/aom_dsp/x86/fwd_txfm_sse2.c index 11c7d88..0e4fb80 100644 --- a/aom_dsp/x86/fwd_txfm_sse2.c +++ b/aom_dsp/x86/fwd_txfm_sse2.c
@@ -18,8 +18,14 @@ #include "aom_dsp/x86/fwd_txfm_sse2.h" #define DCT_HIGH_BIT_DEPTH 0 +#define FDCT4x4_2D_HELPER fdct4x4_helper +#define FDCT4x4_2D aom_fdct4x4_sse2 +#define FDCT4x4_2D_LP aom_fdct4x4_lp_sse2 #define FDCT8x8_2D aom_fdct8x8_sse2 #include "aom_dsp/x86/fwd_txfm_impl_sse2.h" +#undef FDCT4x4_2D_HELPER +#undef FDCT4x4_2D +#undef FDCT4x4_2D_LP #undef FDCT8x8_2D #if CONFIG_AV1_HIGHBITDEPTH
diff --git a/aom_dsp/x86/fwd_txfm_sse2.h b/aom_dsp/x86/fwd_txfm_sse2.h index 260d8dd..ab3cd91 100644 --- a/aom_dsp/x86/fwd_txfm_sse2.h +++ b/aom_dsp/x86/fwd_txfm_sse2.h
@@ -136,16 +136,21 @@ } static INLINE void store_output(const __m128i *poutput, tran_low_t *dst_ptr) { - if (sizeof(tran_low_t) == 4) { - const __m128i zero = _mm_setzero_si128(); - const __m128i sign_bits = _mm_cmplt_epi16(*poutput, zero); - __m128i out0 = _mm_unpacklo_epi16(*poutput, sign_bits); - __m128i out1 = _mm_unpackhi_epi16(*poutput, sign_bits); - _mm_store_si128((__m128i *)(dst_ptr), out0); - _mm_store_si128((__m128i *)(dst_ptr + 4), out1); - } else { - _mm_store_si128((__m128i *)(dst_ptr), *poutput); - } + const __m128i zero = _mm_setzero_si128(); + const __m128i sign_bits = _mm_cmplt_epi16(*poutput, zero); + __m128i out0 = _mm_unpacklo_epi16(*poutput, sign_bits); + __m128i out1 = _mm_unpackhi_epi16(*poutput, sign_bits); + _mm_store_si128((__m128i *)(dst_ptr), out0); + _mm_store_si128((__m128i *)(dst_ptr + 4), out1); +} + +static INLINE void storeu_output(const __m128i *poutput, tran_low_t *dst_ptr) { + const __m128i zero = _mm_setzero_si128(); + const __m128i sign_bits = _mm_cmplt_epi16(*poutput, zero); + __m128i out0 = _mm_unpacklo_epi16(*poutput, sign_bits); + __m128i out1 = _mm_unpackhi_epi16(*poutput, sign_bits); + _mm_storeu_si128((__m128i *)(dst_ptr), out0); + _mm_storeu_si128((__m128i *)(dst_ptr + 4), out1); } #ifdef __cplusplus
diff --git a/aom_dsp/x86/txfm_common_sse2.h b/aom_dsp/x86/txfm_common_sse2.h index ed82eee..9c99eb9 100644 --- a/aom_dsp/x86/txfm_common_sse2.h +++ b/aom_dsp/x86/txfm_common_sse2.h
@@ -26,4 +26,8 @@ return _mm_shuffle_epi32(b, 0x4e); } +#define octa_set_epi16(a, b, c, d, e, f, g, h) \ + _mm_setr_epi16((int16_t)(a), (int16_t)(b), (int16_t)(c), (int16_t)(d), \ + (int16_t)(e), (int16_t)(f), (int16_t)(g), (int16_t)(h)) + #endif // AOM_AOM_DSP_X86_TXFM_COMMON_SSE2_H_
diff --git a/test/fdct4x4_test.cc b/test/fdct4x4_test.cc new file mode 100644 index 0000000..9679777 --- /dev/null +++ b/test/fdct4x4_test.cc
@@ -0,0 +1,123 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <math.h> +#include <stdlib.h> +#include <string.h> + +#include "aom_dsp/aom_dsp_common.h" +#include "third_party/googletest/src/googletest/include/gtest/gtest.h" + +#include "config/av1_rtcd.h" +#include "config/aom_dsp_rtcd.h" +#include "test/acm_random.h" +#include "test/clear_system_state.h" +#include "test/register_state_check.h" +#include "test/transform_test_base.h" +#include "test/util.h" +#include "av1/common/entropy.h" +#include "aom/aom_codec.h" +#include "aom/aom_integer.h" +#include "aom_ports/mem.h" + +using libaom_test::ACMRandom; + +namespace { + +template <typename OutputType> +using FdctFunc = void (*)(const int16_t *in, OutputType *out, int stride); + +template <typename OutputType> +using FhtFunc = void (*)(const int16_t *in, OutputType *out, int stride, + TxfmParam *txfm_param); + +template <typename OutputType> +using Fdct4x4Param = ::testing::tuple<FdctFunc<OutputType>, FhtFunc<OutputType>, + aom_bit_depth_t, int>; + +#if HAVE_NEON || HAVE_SSE2 +void fdct4x4_ref(const int16_t *in, tran_low_t *out, int stride, + TxfmParam * /*txfm_param*/) { + aom_fdct4x4_c(in, out, stride); +} + +void fdct4x4_lp_ref(const int16_t *in, int16_t *out, int stride, + TxfmParam * /*txfm_param*/) { + aom_fdct4x4_lp_c(in, out, stride); +} +#endif + +template <typename OutputType> +class Trans4x4FDCT : public libaom_test::TransformTestBase<OutputType>, + public ::testing::TestWithParam<Fdct4x4Param<OutputType>> 
{ + public: + virtual ~Trans4x4FDCT() {} + + using TxfmBaseOutType = libaom_test::TransformTestBase<OutputType>; + virtual void SetUp() { + fwd_txfm_ = ::testing::get<0>(this->GetParam()); + TxfmBaseOutType::pitch_ = 4; + TxfmBaseOutType::height_ = 4; + TxfmBaseOutType::fwd_txfm_ref = ::testing::get<1>(this->GetParam()); + TxfmBaseOutType::bit_depth_ = ::testing::get<2>(this->GetParam()); + TxfmBaseOutType::mask_ = (1 << TxfmBaseOutType::bit_depth_) - 1; + TxfmBaseOutType::num_coeffs_ = ::testing::get<3>(this->GetParam()); + } + virtual void TearDown() { libaom_test::ClearSystemState(); } + + protected: + void RunFwdTxfm(const int16_t *in, OutputType *out, int stride) { + fwd_txfm_(in, out, stride); + } + + void RunInvTxfm(const OutputType *out, uint8_t *dst, int stride) { + (void)out; + (void)dst; + (void)stride; + } + + FdctFunc<OutputType> fwd_txfm_; +}; + +using Trans4x4FDCTTranLow = Trans4x4FDCT<tran_low_t>; +TEST_P(Trans4x4FDCTTranLow, CoeffCheck) { RunCoeffCheck(); } +TEST_P(Trans4x4FDCTTranLow, MemCheck) { RunMemCheck(); } + +using Trans4x4FDCTInt16 = Trans4x4FDCT<int16_t>; +TEST_P(Trans4x4FDCTInt16, CoeffCheck) { RunCoeffCheck(); } +TEST_P(Trans4x4FDCTInt16, MemCheck) { RunMemCheck(); } + +using ::testing::make_tuple; + +#if HAVE_NEON +INSTANTIATE_TEST_CASE_P(NEON, Trans4x4FDCTTranLow, + ::testing::Values(make_tuple(&aom_fdct4x4_neon, + &fdct4x4_ref, AOM_BITS_8, + 16))); + +INSTANTIATE_TEST_CASE_P(NEON, Trans4x4FDCTInt16, + ::testing::Values(make_tuple(&aom_fdct4x4_lp_neon, + &fdct4x4_lp_ref, + AOM_BITS_8, 16))); +#endif + +#if HAVE_SSE2 +INSTANTIATE_TEST_CASE_P(SSE2, Trans4x4FDCTTranLow, + ::testing::Values(make_tuple(&aom_fdct4x4_sse2, + &fdct4x4_ref, AOM_BITS_8, + 16))); + +INSTANTIATE_TEST_CASE_P(SSE2, Trans4x4FDCTInt16, + ::testing::Values(make_tuple(&aom_fdct4x4_lp_sse2, + &fdct4x4_lp_ref, + AOM_BITS_8, 16))); +#endif +} // namespace
diff --git a/test/fwht4x4_test.cc b/test/fwht4x4_test.cc index c8d98c5..251acb7 100644 --- a/test/fwht4x4_test.cc +++ b/test/fwht4x4_test.cc
@@ -13,6 +13,7 @@ #include <stdlib.h> #include <string.h> +#include "aom_dsp/aom_dsp_common.h" #include "third_party/googletest/src/googletest/include/gtest/gtest.h" #include "config/av1_rtcd.h" @@ -51,7 +52,7 @@ av1_highbd_iwht4x4_16_add_c(in, out, stride, 12); } -class Trans4x4WHT : public libaom_test::TransformTestBase, +class Trans4x4WHT : public libaom_test::TransformTestBase<tran_low_t>, public ::testing::TestWithParam<Dct4x4Param> { public: virtual ~Trans4x4WHT() {}
diff --git a/test/test.cmake b/test/test.cmake index 5ae2ab6..b172e64 100644 --- a/test/test.cmake +++ b/test/test.cmake
@@ -207,6 +207,7 @@ "${AOM_ROOT}/test/error_block_test.cc" "${AOM_ROOT}/test/fft_test.cc" "${AOM_ROOT}/test/fwht4x4_test.cc" + "${AOM_ROOT}/test/fdct4x4_test.cc" "${AOM_ROOT}/test/hadamard_test.cc" "${AOM_ROOT}/test/horver_correlation_test.cc" "${AOM_ROOT}/test/masked_sad_test.cc"
diff --git a/test/transform_test_base.h b/test/transform_test_base.h index 8ebcf5f..68f5cc7 100644 --- a/test/transform_test_base.h +++ b/test/transform_test_base.h
@@ -29,20 +29,23 @@ // to a aom header file. const int kDctMaxValue = 16384; -typedef void (*FhtFunc)(const int16_t *in, tran_low_t *out, int stride, - TxfmParam *txfm_param); +template <typename OutputType> +using FhtFunc = void (*)(const int16_t *in, OutputType *out, int stride, + TxfmParam *txfm_param); -typedef void (*IhtFunc)(const tran_low_t *in, uint8_t *out, int stride, - const TxfmParam *txfm_param); +template <typename OutputType> +using IhtFunc = void (*)(const tran_low_t *in, uint8_t *out, int stride, + const TxfmParam *txfm_param); +template <typename OutType> class TransformTestBase { public: virtual ~TransformTestBase() {} protected: - virtual void RunFwdTxfm(const int16_t *in, tran_low_t *out, int stride) = 0; + virtual void RunFwdTxfm(const int16_t *in, OutType *out, int stride) = 0; - virtual void RunInvTxfm(const tran_low_t *out, uint8_t *dst, int stride) = 0; + virtual void RunInvTxfm(const OutType *out, uint8_t *dst, int stride) = 0; void RunAccuracyCheck(uint32_t ref_max_error, double ref_avg_error) { ACMRandom rnd(ACMRandom::DeterministicSeed()); @@ -52,8 +55,8 @@ int16_t *test_input_block = reinterpret_cast<int16_t *>( aom_memalign(16, sizeof(int16_t) * num_coeffs_)); - tran_low_t *test_temp_block = reinterpret_cast<tran_low_t *>( - aom_memalign(16, sizeof(tran_low_t) * num_coeffs_)); + OutType *test_temp_block = reinterpret_cast<OutType *>( + aom_memalign(16, sizeof(test_temp_block[0]) * num_coeffs_)); uint8_t *dst = reinterpret_cast<uint8_t *>( aom_memalign(16, sizeof(uint8_t) * num_coeffs_)); uint8_t *src = reinterpret_cast<uint8_t *>( @@ -123,10 +126,10 @@ int16_t *input_block = reinterpret_cast<int16_t *>( aom_memalign(16, sizeof(int16_t) * stride * height_)); - tran_low_t *output_ref_block = reinterpret_cast<tran_low_t *>( - aom_memalign(16, sizeof(tran_low_t) * num_coeffs_)); - tran_low_t *output_block = reinterpret_cast<tran_low_t *>( - aom_memalign(16, sizeof(tran_low_t) * num_coeffs_)); + OutType *output_ref_block = 
reinterpret_cast<OutType *>( + aom_memalign(16, sizeof(output_ref_block[0]) * num_coeffs_)); + OutType *output_block = reinterpret_cast<OutType *>( + aom_memalign(16, sizeof(output_block[0]) * num_coeffs_)); for (int i = 0; i < count_test_block; ++i) { int j, k; @@ -172,8 +175,8 @@ int16_t *input_block = reinterpret_cast<int16_t *>( aom_memalign(16, sizeof(int16_t) * num_coeffs_)); - tran_low_t *trans_block = reinterpret_cast<tran_low_t *>( - aom_memalign(16, sizeof(tran_low_t) * num_coeffs_)); + OutType *trans_block = reinterpret_cast<OutType *>( + aom_memalign(16, sizeof(trans_block[0]) * num_coeffs_)); uint8_t *output_block = reinterpret_cast<uint8_t *>( aom_memalign(16, sizeof(uint8_t) * stride * height_)); uint8_t *output_ref_block = reinterpret_cast<uint8_t *>( @@ -218,10 +221,10 @@ int16_t *input_extreme_block = reinterpret_cast<int16_t *>( aom_memalign(16, sizeof(int16_t) * num_coeffs_)); - tran_low_t *output_ref_block = reinterpret_cast<tran_low_t *>( - aom_memalign(16, sizeof(tran_low_t) * num_coeffs_)); - tran_low_t *output_block = reinterpret_cast<tran_low_t *>( - aom_memalign(16, sizeof(tran_low_t) * num_coeffs_)); + OutType *output_ref_block = reinterpret_cast<OutType *>( + aom_memalign(16, sizeof(output_ref_block[0]) * num_coeffs_)); + OutType *output_block = reinterpret_cast<OutType *>( + aom_memalign(16, sizeof(output_block[0]) * num_coeffs_)); for (int i = 0; i < count_test_block; ++i) { // Initialize a test block with input range [-mask_, mask_]. 
@@ -260,8 +263,8 @@ int16_t *in = reinterpret_cast<int16_t *>( aom_memalign(16, sizeof(int16_t) * num_coeffs_)); - tran_low_t *coeff = reinterpret_cast<tran_low_t *>( - aom_memalign(16, sizeof(tran_low_t) * num_coeffs_)); + OutType *coeff = reinterpret_cast<OutType *>( + aom_memalign(16, sizeof(coeff[0]) * num_coeffs_)); uint8_t *dst = reinterpret_cast<uint8_t *>( aom_memalign(16, sizeof(uint8_t) * num_coeffs_)); uint8_t *src = reinterpret_cast<uint8_t *>( @@ -313,8 +316,8 @@ int pitch_; int height_; - FhtFunc fwd_txfm_ref; - IhtFunc inv_txfm_ref; + FhtFunc<OutType> fwd_txfm_ref; + IhtFunc<OutType> inv_txfm_ref; aom_bit_depth_t bit_depth_; int mask_; int num_coeffs_;