Fix aom_fdct32x32_avx2 output when CONFIG_AOM_HIGHBITDEPTH=1

- Change FDCT32x32_2D_AVX2 output parameter to tran_low_t (sketched below).
- Add unit tests for CONFIG_AOM_HIGHBITDEPTH=1.
- Update TODO notes.
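
With CONFIG_AOM_HIGHBITDEPTH=1, tran_low_t is int32_t, so each int16_t
coefficient lane has to be sign-extended when it is stored. A minimal
sketch of that widening store (illustrative names only; the real helper
is storeu_output in txfm_common_intrin.h):

  #include <emmintrin.h>  // SSE2
  #include <stdint.h>

  typedef int32_t tran_low_t;  // CONFIG_AOM_HIGHBITDEPTH=1 layout

  // Sign-extend eight int16 coefficients to int32 and store them.
  static void widen_store8(const __m128i *coeff, tran_low_t *dst) {
    const __m128i zero = _mm_setzero_si128();
    // 0xFFFF in lanes holding negative coefficients, 0x0000 elsewhere.
    const __m128i sign = _mm_cmplt_epi16(*coeff, zero);
    _mm_storeu_si128((__m128i *)(dst + 0), _mm_unpacklo_epi16(*coeff, sign));
    _mm_storeu_si128((__m128i *)(dst + 4), _mm_unpackhi_epi16(*coeff, sign));
  }
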
BUG=webm:1323

Change-Id: If4766c919a24231fce886de74658b6dd7a011246
diff --git a/aom_dsp/aom_dsp.mk b/aom_dsp/aom_dsp.mk
index 4735199..eebdc0c 100644
--- a/aom_dsp/aom_dsp.mk
+++ b/aom_dsp/aom_dsp.mk
@@ -191,6 +191,7 @@
 endif  # CONFIG_AOM_HIGHBITDEPTH
 
 DSP_SRCS-yes            += txfm_common.h
+DSP_SRCS-yes            += x86/txfm_common_intrin.h
 DSP_SRCS-$(HAVE_SSE2)   += x86/txfm_common_sse2.h
 DSP_SRCS-$(HAVE_MSA)    += mips/txfm_macros_msa.h
 # forward transform
diff --git a/aom_dsp/x86/fwd_dct32x32_impl_avx2.h b/aom_dsp/x86/fwd_dct32x32_impl_avx2.h
index 8b136e7..2167395 100644
--- a/aom_dsp/x86/fwd_dct32x32_impl_avx2.h
+++ b/aom_dsp/x86/fwd_dct32x32_impl_avx2.h
@@ -12,6 +12,7 @@
 #include <immintrin.h>  // AVX2
 
 #include "aom_dsp/txfm_common.h"
+#include "aom_dsp/x86/txfm_common_intrin.h"
 #include "aom_dsp/x86/txfm_common_avx2.h"
 
 #if FDCT32x32_HIGH_PRECISION
@@ -31,7 +32,22 @@
 }
 #endif
 
-void FDCT32x32_2D_AVX2(const int16_t *input, int16_t *output_org, int stride) {
+#ifndef STORE_COEFF_FUNC
+#define STORE_COEFF_FUNC
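+// Store the low and high 128-bit halves of *coeff to the current and next
+// output rows; storeu_output widens each lane to tran_low_t when
+// CONFIG_AOM_HIGHBITDEPTH=1.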
+static void store_coeff(const __m256i *coeff, tran_low_t *curr,
+                        tran_low_t *next) {
+  __m128i u = _mm256_castsi256_si128(*coeff);
+  storeu_output(&u, curr);
+  u = _mm256_extractf128_si256(*coeff, 1);
+  storeu_output(&u, next);
+}
+#endif
+
+void FDCT32x32_2D_AVX2(const int16_t *input, tran_low_t *output_org,
+                       int stride) {
   // Calculate pre-multiplied strides
   const int str1 = stride;
   const int str2 = 2 * stride;
@@ -2842,13 +2855,14 @@
       {
         int transpose_block;
         int16_t *output_currStep, *output_nextStep;
-        if (0 == pass) {
-          output_currStep = &intermediate[column_start * 32];
-          output_nextStep = &intermediate[(column_start + 8) * 32];
-        } else {
-          output_currStep = &output_org[column_start * 32];
-          output_nextStep = &output_org[(column_start + 8) * 32];
-        }
+        tran_low_t *curr_out, *next_out;
+        // Pass 0
+        output_currStep = &intermediate[column_start * 32];
+        output_nextStep = &intermediate[(column_start + 8) * 32];
+        // Pass 1
+        curr_out = &output_org[column_start * 32];
+        next_out = &output_org[(column_start + 8) * 32];
+
         for (transpose_block = 0; transpose_block < 4; ++transpose_block) {
           __m256i *this_out = &out[8 * transpose_block];
           // 00  01  02  03  04  05  06  07  08  09  10  11  12  13  14  15
@@ -2948,44 +2962,58 @@
             tr2_6 = _mm256_srai_epi16(tr2_6, 2);
             tr2_7 = _mm256_srai_epi16(tr2_7, 2);
           }
-          // Note: even though all these stores are aligned, using the aligned
-          //       intrinsic make the code slightly slower.
-          _mm_storeu_si128((__m128i *)(output_currStep + 0 * 32),
-                           _mm256_castsi256_si128(tr2_0));
-          _mm_storeu_si128((__m128i *)(output_currStep + 1 * 32),
-                           _mm256_castsi256_si128(tr2_1));
-          _mm_storeu_si128((__m128i *)(output_currStep + 2 * 32),
-                           _mm256_castsi256_si128(tr2_2));
-          _mm_storeu_si128((__m128i *)(output_currStep + 3 * 32),
-                           _mm256_castsi256_si128(tr2_3));
-          _mm_storeu_si128((__m128i *)(output_currStep + 4 * 32),
-                           _mm256_castsi256_si128(tr2_4));
-          _mm_storeu_si128((__m128i *)(output_currStep + 5 * 32),
-                           _mm256_castsi256_si128(tr2_5));
-          _mm_storeu_si128((__m128i *)(output_currStep + 6 * 32),
-                           _mm256_castsi256_si128(tr2_6));
-          _mm_storeu_si128((__m128i *)(output_currStep + 7 * 32),
-                           _mm256_castsi256_si128(tr2_7));
+          if (0 == pass) {
+            // Note: even though all these stores are aligned, using the aligned
+            //       intrinsic makes the code slightly slower.
+            _mm_storeu_si128((__m128i *)(output_currStep + 0 * 32),
+                             _mm256_castsi256_si128(tr2_0));
+            _mm_storeu_si128((__m128i *)(output_currStep + 1 * 32),
+                             _mm256_castsi256_si128(tr2_1));
+            _mm_storeu_si128((__m128i *)(output_currStep + 2 * 32),
+                             _mm256_castsi256_si128(tr2_2));
+            _mm_storeu_si128((__m128i *)(output_currStep + 3 * 32),
+                             _mm256_castsi256_si128(tr2_3));
+            _mm_storeu_si128((__m128i *)(output_currStep + 4 * 32),
+                             _mm256_castsi256_si128(tr2_4));
+            _mm_storeu_si128((__m128i *)(output_currStep + 5 * 32),
+                             _mm256_castsi256_si128(tr2_5));
+            _mm_storeu_si128((__m128i *)(output_currStep + 6 * 32),
+                             _mm256_castsi256_si128(tr2_6));
+            _mm_storeu_si128((__m128i *)(output_currStep + 7 * 32),
+                             _mm256_castsi256_si128(tr2_7));
 
-          _mm_storeu_si128((__m128i *)(output_nextStep + 0 * 32),
-                           _mm256_extractf128_si256(tr2_0, 1));
-          _mm_storeu_si128((__m128i *)(output_nextStep + 1 * 32),
-                           _mm256_extractf128_si256(tr2_1, 1));
-          _mm_storeu_si128((__m128i *)(output_nextStep + 2 * 32),
-                           _mm256_extractf128_si256(tr2_2, 1));
-          _mm_storeu_si128((__m128i *)(output_nextStep + 3 * 32),
-                           _mm256_extractf128_si256(tr2_3, 1));
-          _mm_storeu_si128((__m128i *)(output_nextStep + 4 * 32),
-                           _mm256_extractf128_si256(tr2_4, 1));
-          _mm_storeu_si128((__m128i *)(output_nextStep + 5 * 32),
-                           _mm256_extractf128_si256(tr2_5, 1));
-          _mm_storeu_si128((__m128i *)(output_nextStep + 6 * 32),
-                           _mm256_extractf128_si256(tr2_6, 1));
-          _mm_storeu_si128((__m128i *)(output_nextStep + 7 * 32),
-                           _mm256_extractf128_si256(tr2_7, 1));
-          // Process next 8x8
-          output_currStep += 8;
-          output_nextStep += 8;
+            _mm_storeu_si128((__m128i *)(output_nextStep + 0 * 32),
+                             _mm256_extractf128_si256(tr2_0, 1));
+            _mm_storeu_si128((__m128i *)(output_nextStep + 1 * 32),
+                             _mm256_extractf128_si256(tr2_1, 1));
+            _mm_storeu_si128((__m128i *)(output_nextStep + 2 * 32),
+                             _mm256_extractf128_si256(tr2_2, 1));
+            _mm_storeu_si128((__m128i *)(output_nextStep + 3 * 32),
+                             _mm256_extractf128_si256(tr2_3, 1));
+            _mm_storeu_si128((__m128i *)(output_nextStep + 4 * 32),
+                             _mm256_extractf128_si256(tr2_4, 1));
+            _mm_storeu_si128((__m128i *)(output_nextStep + 5 * 32),
+                             _mm256_extractf128_si256(tr2_5, 1));
+            _mm_storeu_si128((__m128i *)(output_nextStep + 6 * 32),
+                             _mm256_extractf128_si256(tr2_6, 1));
+            _mm_storeu_si128((__m128i *)(output_nextStep + 7 * 32),
+                             _mm256_extractf128_si256(tr2_7, 1));
+            // Process next 8x8
+            output_currStep += 8;
+            output_nextStep += 8;
+          }
+          if (1 == pass) {
+            store_coeff(&tr2_0, curr_out + 0 * 32, next_out + 0 * 32);
+            store_coeff(&tr2_1, curr_out + 1 * 32, next_out + 1 * 32);
+            store_coeff(&tr2_2, curr_out + 2 * 32, next_out + 2 * 32);
+            store_coeff(&tr2_3, curr_out + 3 * 32, next_out + 3 * 32);
+            store_coeff(&tr2_4, curr_out + 4 * 32, next_out + 4 * 32);
+            store_coeff(&tr2_5, curr_out + 5 * 32, next_out + 5 * 32);
+            store_coeff(&tr2_6, curr_out + 6 * 32, next_out + 6 * 32);
+            store_coeff(&tr2_7, curr_out + 7 * 32, next_out + 7 * 32);
+            curr_out += 8;
+            next_out += 8;
+          }
         }
       }
     }
diff --git a/aom_dsp/x86/fwd_txfm_avx2.c b/aom_dsp/x86/fwd_txfm_avx2.c
index d381a6e..670f864 100644
--- a/aom_dsp/x86/fwd_txfm_avx2.c
+++ b/aom_dsp/x86/fwd_txfm_avx2.c
@@ -17,14 +17,6 @@
 #undef FDCT32x32_2D_AVX2
 #undef FDCT32x32_HIGH_PRECISION
 
-// TODO(luoyi): The following macro hides an error. The second parameter type of
-// function,
-//   void FDCT32x32_2D_AVX2(const int16_t *, int16_t*, int);
-// is different from the one in,
-//   void aom_fdct32x32_avx2(const int16_t *, tran_low_t*, int);
-// In CONFIG_AOM_HIGHBITDEPTH=1 build, the second parameter type should be
-// int32_t.
-// This function should be removed after av1_fht32x32 scaling/rounding fix.
 #define FDCT32x32_2D_AVX2 aom_fdct32x32_avx2
 #define FDCT32x32_HIGH_PRECISION 1
 #include "aom_dsp/x86/fwd_dct32x32_impl_avx2.h"  // NOLINT
diff --git a/aom_dsp/x86/fwd_txfm_sse2.h b/aom_dsp/x86/fwd_txfm_sse2.h
index 3261584..fe3e446 100644
--- a/aom_dsp/x86/fwd_txfm_sse2.h
+++ b/aom_dsp/x86/fwd_txfm_sse2.h
@@ -12,6 +12,8 @@
 #ifndef AOM_DSP_X86_FWD_TXFM_SSE2_H_
 #define AOM_DSP_X86_FWD_TXFM_SSE2_H_
 
+#include "aom_dsp/x86/txfm_common_intrin.h"
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -257,19 +259,6 @@
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 }
 
-static INLINE void storeu_output(const __m128i *poutput, tran_low_t *dst_ptr) {
-#if CONFIG_AOM_HIGHBITDEPTH
-  const __m128i zero = _mm_setzero_si128();
-  const __m128i sign_bits = _mm_cmplt_epi16(*poutput, zero);
-  __m128i out0 = _mm_unpacklo_epi16(*poutput, sign_bits);
-  __m128i out1 = _mm_unpackhi_epi16(*poutput, sign_bits);
-  _mm_storeu_si128((__m128i *)(dst_ptr), out0);
-  _mm_storeu_si128((__m128i *)(dst_ptr + 4), out1);
-#else
-  _mm_storeu_si128((__m128i *)(dst_ptr), *poutput);
-#endif  // CONFIG_AOM_HIGHBITDEPTH
-}
-
 static INLINE __m128i mult_round_shift(const __m128i *pin0, const __m128i *pin1,
                                        const __m128i *pmultiplier,
                                        const __m128i *prounding,
diff --git a/aom_dsp/x86/txfm_common_intrin.h b/aom_dsp/x86/txfm_common_intrin.h
new file mode 100644
index 0000000..890e048
--- /dev/null
+++ b/aom_dsp/x86/txfm_common_intrin.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2016, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_DSP_X86_TXFM_COMMON_INTRIN_H_
+#define AOM_DSP_X86_TXFM_COMMON_INTRIN_H_
+
+// Note:
+//  This header file should be included after any x86 intrinsics header files.
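+//  For example:
+//    #include <immintrin.h>                       // x86 intrinsics first
+//    #include "aom_dsp/x86/txfm_common_intrin.h"  // then this header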
+
+static INLINE void storeu_output(const __m128i *poutput, tran_low_t *dst_ptr) {
+#if CONFIG_AOM_HIGHBITDEPTH
+  const __m128i zero = _mm_setzero_si128();
+  const __m128i sign_bits = _mm_cmplt_epi16(*poutput, zero);
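+  // Interleaving each int16 lane with its sign mask sign-extends it to int32.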
+  __m128i out0 = _mm_unpacklo_epi16(*poutput, sign_bits);
+  __m128i out1 = _mm_unpackhi_epi16(*poutput, sign_bits);
+  _mm_storeu_si128((__m128i *)(dst_ptr), out0);
+  _mm_storeu_si128((__m128i *)(dst_ptr + 4), out1);
+#else
+  _mm_storeu_si128((__m128i *)(dst_ptr), *poutput);
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+}
+
+#endif  // AOM_DSP_X86_TXFM_COMMON_INTRIN_H_
diff --git a/test/dct32x32_test.cc b/test/dct32x32_test.cc
index e4179ef..cb2fbd5 100644
--- a/test/dct32x32_test.cc
+++ b/test/dct32x32_test.cc
@@ -436,6 +436,15 @@
                                  &aom_idct32x32_1024_add_sse2, 1, AOM_BITS_8)));
 #endif  // HAVE_AVX2 && !CONFIG_AOM_HIGHBITDEPTH && !CONFIG_EMULATE_HARDWARE
 
+#if HAVE_AVX2 && CONFIG_AOM_HIGHBITDEPTH && !CONFIG_EMULATE_HARDWARE
+INSTANTIATE_TEST_CASE_P(
+    AVX2, Trans32x32Test,
+    ::testing::Values(make_tuple(&aom_fdct32x32_avx2,
+                                 &aom_idct32x32_1024_add_sse2, 0, AOM_BITS_8),
+                      make_tuple(&aom_fdct32x32_rd_avx2,
+                                 &aom_idct32x32_1024_add_sse2, 1, AOM_BITS_8)));
+#endif  // HAVE_AVX2 && CONFIG_AOM_HIGHBITDEPTH && !CONFIG_EMULATE_HARDWARE
+
 #if HAVE_MSA && !CONFIG_AOM_HIGHBITDEPTH && !CONFIG_EMULATE_HARDWARE
 INSTANTIATE_TEST_CASE_P(
     MSA, Trans32x32Test,
diff --git a/test/fht32x32_test.cc b/test/fht32x32_test.cc
index 1f85761..8545b2c 100644
--- a/test/fht32x32_test.cc
+++ b/test/fht32x32_test.cc
@@ -90,12 +90,11 @@
   IhtFunc inv_txfm_;
 };
 
-// TODO(luoyi): Owing to the range check in DCT_DCT of av1_fht32x32_avx2, as
-// input is out of the range, we use aom_fdct32x32_avx2. However this function
-// does not support CONFIG_AOM_HIGHBITDEPTH. I need to fix the scaling/rounding
-// of av1_fht32x32_avx2 then add this test on CONFIG_AOM_HIGHBITDEPTH.
-#if !CONFIG_AOM_HIGHBITDEPTH
 TEST_P(AV1Trans32x32HT, CoeffCheck) { RunCoeffCheck(); }
+// TODO(luoyi): When CONFIG_AOM_HIGHBITDEPTH=1, our AVX2 implementation of
+// av1_fht32x32 does not support tran_low_t (int32_t) intermediate results.
+// Therefore the MemCheck test for tx_type=1,2,...,8 cannot pass yet.
+#if !CONFIG_AOM_HIGHBITDEPTH
 TEST_P(AV1Trans32x32HT, MemCheck) { RunMemCheck(); }
 #endif