Integrate HBD inverse HT flip types sse4.1 optimization

- tx_size: 4x4, 8x8, 16x16.
- tx_type: FLIPADST_DCT, DCT_FLIPADST, FLIPADST_FLIPADST,
  ADST_FLIPADST, FLIPADST_ADST.
- Encoder speed improvement:
  park_joy_1080p_12: ~11%, crowd_run_1080p_12: ~7%.
- Add unit test cases for bit-exact against C.

Change-Id: Ia69d069031fa76c4625e845bfbfe7e6f6ed6e841
diff --git a/test/vp10_highbd_iht_test.cc b/test/vp10_highbd_iht_test.cc
index 0b7597d..caab04c 100644
--- a/test/vp10_highbd_iht_test.cc
+++ b/test/vp10_highbd_iht_test.cc
@@ -15,6 +15,7 @@
 #include "test/clear_system_state.h"
 #include "test/register_state_check.h"
 #include "test/util.h"
+#include "vp10/common/enums.h"
 #include "vpx_dsp/vpx_dsp_common.h"
 #include "vpx_ports/mem.h"
 
@@ -149,32 +150,68 @@
 
 const IHbdHtParam kArrayIhtParam[] = {
   // 16x16
-  make_tuple(PARAM_LIST_16X16, 0, 10),
-  make_tuple(PARAM_LIST_16X16, 0, 12),
-  make_tuple(PARAM_LIST_16X16, 1, 10),
-  make_tuple(PARAM_LIST_16X16, 1, 12),
-  make_tuple(PARAM_LIST_16X16, 2, 10),
-  make_tuple(PARAM_LIST_16X16, 2, 12),
-  make_tuple(PARAM_LIST_16X16, 3, 10),
-  make_tuple(PARAM_LIST_16X16, 3, 12),
+  make_tuple(PARAM_LIST_16X16, DCT_DCT, 10),
+  make_tuple(PARAM_LIST_16X16, DCT_DCT, 12),
+  make_tuple(PARAM_LIST_16X16, ADST_DCT, 10),
+  make_tuple(PARAM_LIST_16X16, ADST_DCT, 12),
+  make_tuple(PARAM_LIST_16X16, DCT_ADST, 10),
+  make_tuple(PARAM_LIST_16X16, DCT_ADST, 12),
+  make_tuple(PARAM_LIST_16X16, ADST_ADST, 10),
+  make_tuple(PARAM_LIST_16X16, ADST_ADST, 12),
+#if CONFIG_EXT_TX
+  make_tuple(PARAM_LIST_16X16, FLIPADST_DCT, 10),
+  make_tuple(PARAM_LIST_16X16, FLIPADST_DCT, 12),
+  make_tuple(PARAM_LIST_16X16, DCT_FLIPADST, 10),
+  make_tuple(PARAM_LIST_16X16, DCT_FLIPADST, 12),
+  make_tuple(PARAM_LIST_16X16, FLIPADST_FLIPADST, 10),
+  make_tuple(PARAM_LIST_16X16, FLIPADST_FLIPADST, 12),
+  make_tuple(PARAM_LIST_16X16, ADST_FLIPADST, 10),
+  make_tuple(PARAM_LIST_16X16, ADST_FLIPADST, 12),
+  make_tuple(PARAM_LIST_16X16, FLIPADST_ADST, 10),
+  make_tuple(PARAM_LIST_16X16, FLIPADST_ADST, 12),
+#endif
   // 8x8
-  make_tuple(PARAM_LIST_8X8, 0, 10),
-  make_tuple(PARAM_LIST_8X8, 0, 12),
-  make_tuple(PARAM_LIST_8X8, 1, 10),
-  make_tuple(PARAM_LIST_8X8, 1, 12),
-  make_tuple(PARAM_LIST_8X8, 2, 10),
-  make_tuple(PARAM_LIST_8X8, 2, 12),
-  make_tuple(PARAM_LIST_8X8, 3, 10),
-  make_tuple(PARAM_LIST_8X8, 3, 12),
+  make_tuple(PARAM_LIST_8X8, DCT_DCT, 10),
+  make_tuple(PARAM_LIST_8X8, DCT_DCT, 12),
+  make_tuple(PARAM_LIST_8X8, ADST_DCT, 10),
+  make_tuple(PARAM_LIST_8X8, ADST_DCT, 12),
+  make_tuple(PARAM_LIST_8X8, DCT_ADST, 10),
+  make_tuple(PARAM_LIST_8X8, DCT_ADST, 12),
+  make_tuple(PARAM_LIST_8X8, ADST_ADST, 10),
+  make_tuple(PARAM_LIST_8X8, ADST_ADST, 12),
+#if CONFIG_EXT_TX
+  make_tuple(PARAM_LIST_8X8, FLIPADST_DCT, 10),
+  make_tuple(PARAM_LIST_8X8, FLIPADST_DCT, 12),
+  make_tuple(PARAM_LIST_8X8, DCT_FLIPADST, 10),
+  make_tuple(PARAM_LIST_8X8, DCT_FLIPADST, 12),
+  make_tuple(PARAM_LIST_8X8, FLIPADST_FLIPADST, 10),
+  make_tuple(PARAM_LIST_8X8, FLIPADST_FLIPADST, 12),
+  make_tuple(PARAM_LIST_8X8, ADST_FLIPADST, 10),
+  make_tuple(PARAM_LIST_8X8, ADST_FLIPADST, 12),
+  make_tuple(PARAM_LIST_8X8, FLIPADST_ADST, 10),
+  make_tuple(PARAM_LIST_8X8, FLIPADST_ADST, 12),
+#endif
   // 4x4
-  make_tuple(PARAM_LIST_4X4, 0, 10),
-  make_tuple(PARAM_LIST_4X4, 0, 12),
-  make_tuple(PARAM_LIST_4X4, 1, 10),
-  make_tuple(PARAM_LIST_4X4, 1, 12),
-  make_tuple(PARAM_LIST_4X4, 2, 10),
-  make_tuple(PARAM_LIST_4X4, 2, 12),
-  make_tuple(PARAM_LIST_4X4, 3, 10),
-  make_tuple(PARAM_LIST_4X4, 3, 12),
+  make_tuple(PARAM_LIST_4X4, DCT_DCT, 10),
+  make_tuple(PARAM_LIST_4X4, DCT_DCT, 12),
+  make_tuple(PARAM_LIST_4X4, ADST_DCT, 10),
+  make_tuple(PARAM_LIST_4X4, ADST_DCT, 12),
+  make_tuple(PARAM_LIST_4X4, DCT_ADST, 10),
+  make_tuple(PARAM_LIST_4X4, DCT_ADST, 12),
+  make_tuple(PARAM_LIST_4X4, ADST_ADST, 10),
+  make_tuple(PARAM_LIST_4X4, ADST_ADST, 12),
+#if CONFIG_EXT_TX
+  make_tuple(PARAM_LIST_4X4, FLIPADST_DCT, 10),
+  make_tuple(PARAM_LIST_4X4, FLIPADST_DCT, 12),
+  make_tuple(PARAM_LIST_4X4, DCT_FLIPADST, 10),
+  make_tuple(PARAM_LIST_4X4, DCT_FLIPADST, 12),
+  make_tuple(PARAM_LIST_4X4, FLIPADST_FLIPADST, 10),
+  make_tuple(PARAM_LIST_4X4, FLIPADST_FLIPADST, 12),
+  make_tuple(PARAM_LIST_4X4, ADST_FLIPADST, 10),
+  make_tuple(PARAM_LIST_4X4, ADST_FLIPADST, 12),
+  make_tuple(PARAM_LIST_4X4, FLIPADST_ADST, 10),
+  make_tuple(PARAM_LIST_4X4, FLIPADST_ADST, 12),
+#endif
 };
 
 INSTANTIATE_TEST_CASE_P(
diff --git a/vp10/common/idct.c b/vp10/common/idct.c
index 717c914..179b903 100644
--- a/vp10/common/idct.c
+++ b/vp10/common/idct.c
@@ -1297,7 +1297,7 @@
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
-      vp10_inv_txfm2d_add_4x4_c(input, CONVERT_TO_SHORTPTR(dest), stride,
+      vp10_inv_txfm2d_add_4x4(input, CONVERT_TO_SHORTPTR(dest), stride,
                               tx_type, bd);
       break;
     case V_DCT:
@@ -1337,7 +1337,7 @@
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
-      vp10_inv_txfm2d_add_8x8_c(input, CONVERT_TO_SHORTPTR(dest), stride,
+      vp10_inv_txfm2d_add_8x8(input, CONVERT_TO_SHORTPTR(dest), stride,
                               tx_type, bd);
       break;
     case V_DCT:
@@ -1377,7 +1377,7 @@
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
-      vp10_inv_txfm2d_add_16x16_c(input, CONVERT_TO_SHORTPTR(dest), stride,
+      vp10_inv_txfm2d_add_16x16(input, CONVERT_TO_SHORTPTR(dest), stride,
                                 tx_type, bd);
       break;
     case V_DCT:
diff --git a/vp10/common/x86/highbd_inv_txfm_sse4.c b/vp10/common/x86/highbd_inv_txfm_sse4.c
index 9ece108..349aec5 100644
--- a/vp10/common/x86/highbd_inv_txfm_sse4.c
+++ b/vp10/common/x86/highbd_inv_txfm_sse4.c
@@ -176,7 +176,7 @@
 }
 
 static void write_buffer_4x4(__m128i *in, uint16_t *output, int stride,
-                             int shift, int bd) {
+                             int fliplr, int flipud, int shift, int bd) {
   const __m128i zero = _mm_setzero_si128();
   __m128i u0, u1, u2, u3;
   __m128i v0, v1, v2, v3;
@@ -193,10 +193,24 @@
   v2 = _mm_unpacklo_epi16(v2, zero);
   v3 = _mm_unpacklo_epi16(v3, zero);
 
-  u0 = _mm_add_epi32(in[0], v0);
-  u1 = _mm_add_epi32(in[1], v1);
-  u2 = _mm_add_epi32(in[2], v2);
-  u3 = _mm_add_epi32(in[3], v3);
+  if (fliplr) {
+    in[0] = _mm_shuffle_epi32(in[0], 0x1B);
+    in[1] = _mm_shuffle_epi32(in[1], 0x1B);
+    in[2] = _mm_shuffle_epi32(in[2], 0x1B);
+    in[3] = _mm_shuffle_epi32(in[3], 0x1B);
+  }
+
+  if (flipud) {
+    u0 = _mm_add_epi32(in[3], v0);
+    u1 = _mm_add_epi32(in[2], v1);
+    u2 = _mm_add_epi32(in[1], v2);
+    u3 = _mm_add_epi32(in[0], v3);
+  } else {
+    u0 = _mm_add_epi32(in[0], v0);
+    u1 = _mm_add_epi32(in[1], v1);
+    u2 = _mm_add_epi32(in[2], v2);
+    u3 = _mm_add_epi32(in[3], v3);
+  }
 
   v0 = _mm_packus_epi32(u0, u1);
   v2 = _mm_packus_epi32(u2, u3);
@@ -226,29 +240,66 @@
       load_buffer_4x4(coeff, in);
       idct4x4_sse4_1(in, cfg->cos_bit_row[2]);
       idct4x4_sse4_1(in, cfg->cos_bit_col[2]);
-      write_buffer_4x4(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_4x4(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
     case ADST_DCT:
       cfg = &inv_txfm_2d_cfg_adst_dct_4;
       load_buffer_4x4(coeff, in);
       idct4x4_sse4_1(in, cfg->cos_bit_row[2]);
       iadst4x4_sse4_1(in, cfg->cos_bit_col[2]);
-      write_buffer_4x4(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_4x4(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
     case DCT_ADST:
       cfg = &inv_txfm_2d_cfg_dct_adst_4;
       load_buffer_4x4(coeff, in);
       iadst4x4_sse4_1(in, cfg->cos_bit_row[2]);
       idct4x4_sse4_1(in, cfg->cos_bit_col[2]);
-      write_buffer_4x4(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_4x4(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
     case ADST_ADST:
       cfg = &inv_txfm_2d_cfg_adst_adst_4;
       load_buffer_4x4(coeff, in);
       iadst4x4_sse4_1(in, cfg->cos_bit_row[2]);
       iadst4x4_sse4_1(in, cfg->cos_bit_col[2]);
-      write_buffer_4x4(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_4x4(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
+#if CONFIG_EXT_TX
+    case FLIPADST_DCT:
+      cfg = &inv_txfm_2d_cfg_adst_dct_4;
+      load_buffer_4x4(coeff, in);
+      idct4x4_sse4_1(in, cfg->cos_bit_row[2]);
+      iadst4x4_sse4_1(in, cfg->cos_bit_col[2]);
+      write_buffer_4x4(in, output, stride, 0, 1, -cfg->shift[1], bd);
+      break;
+    case DCT_FLIPADST:
+      cfg = &inv_txfm_2d_cfg_dct_adst_4;
+      load_buffer_4x4(coeff, in);
+      iadst4x4_sse4_1(in, cfg->cos_bit_row[2]);
+      idct4x4_sse4_1(in, cfg->cos_bit_col[2]);
+      write_buffer_4x4(in, output, stride, 1, 0, -cfg->shift[1], bd);
+      break;
+    case FLIPADST_FLIPADST:
+      cfg = &inv_txfm_2d_cfg_adst_adst_4;
+      load_buffer_4x4(coeff, in);
+      iadst4x4_sse4_1(in, cfg->cos_bit_row[2]);
+      iadst4x4_sse4_1(in, cfg->cos_bit_col[2]);
+      write_buffer_4x4(in, output, stride, 1, 1, -cfg->shift[1], bd);
+      break;
+    case ADST_FLIPADST:
+      cfg = &inv_txfm_2d_cfg_adst_adst_4;
+      load_buffer_4x4(coeff, in);
+      iadst4x4_sse4_1(in, cfg->cos_bit_row[2]);
+      iadst4x4_sse4_1(in, cfg->cos_bit_col[2]);
+      write_buffer_4x4(in, output, stride, 1, 0, -cfg->shift[1], bd);
+      break;
+    case FLIPADST_ADST:
+      cfg = &inv_txfm_2d_cfg_adst_adst_4;
+      load_buffer_4x4(coeff, in);
+      iadst4x4_sse4_1(in, cfg->cos_bit_row[2]);
+      iadst4x4_sse4_1(in, cfg->cos_bit_col[2]);
+      write_buffer_4x4(in, output, stride, 0, 1, -cfg->shift[1], bd);
+      break;
+#endif  // CONFIG_EXT_TX
     default:
       assert(0);
   }
@@ -576,12 +627,32 @@
   round_shift_4x4(&in[12], shift);
 }
 
-static void write_buffer_8x8(__m128i *in, uint16_t *output, int stride,
-                             int shift, int bd) {
+static __m128i get_recon_8x8(const __m128i pred, __m128i res_lo,
+                             __m128i res_hi, int fliplr, int bd) {
+  __m128i x0, x1;
   const __m128i zero = _mm_setzero_si128();
+
+  x0 = _mm_unpacklo_epi16(pred, zero);
+  x1 = _mm_unpackhi_epi16(pred, zero);
+
+  if (fliplr) {
+    res_lo = _mm_shuffle_epi32(res_lo, 0x1B);
+    res_hi = _mm_shuffle_epi32(res_hi, 0x1B);
+    x0 = _mm_add_epi32(res_hi, x0);
+    x1 = _mm_add_epi32(res_lo, x1);
+  } else {
+    x0 = _mm_add_epi32(res_lo, x0);
+    x1 = _mm_add_epi32(res_hi, x1);
+  }
+
+  x0 = _mm_packus_epi32(x0, x1);
+  return highbd_clamp_epi16(x0, bd);
+}
+
+static void write_buffer_8x8(__m128i *in, uint16_t *output, int stride,
+                             int fliplr, int flipud, int shift, int bd) {
   __m128i u0, u1, u2, u3, u4, u5, u6, u7;
   __m128i v0, v1, v2, v3, v4, v5, v6, v7;
-  __m128i x0, x1;
 
   round_shift_8x8(in, shift);
 
@@ -594,61 +666,25 @@
   v6 = _mm_load_si128((__m128i const *)(output + 6 * stride));
   v7 = _mm_load_si128((__m128i const *)(output + 7 * stride));
 
-  x0 = _mm_unpacklo_epi16(v0, zero);
-  x1 = _mm_unpackhi_epi16(v0, zero);
-  x0 = _mm_add_epi32(in[0], x0);
-  x1 = _mm_add_epi32(in[1], x1);
-  x0 = _mm_packus_epi32(x0, x1);
-  u0 = highbd_clamp_epi16(x0, bd);
-
-  x0 = _mm_unpacklo_epi16(v1, zero);
-  x1 = _mm_unpackhi_epi16(v1, zero);
-  x0 = _mm_add_epi32(in[2], x0);
-  x1 = _mm_add_epi32(in[3], x1);
-  x0 = _mm_packus_epi32(x0, x1);
-  u1 = highbd_clamp_epi16(x0, bd);
-
-  x0 = _mm_unpacklo_epi16(v2, zero);
-  x1 = _mm_unpackhi_epi16(v2, zero);
-  x0 = _mm_add_epi32(in[4], x0);
-  x1 = _mm_add_epi32(in[5], x1);
-  x0 = _mm_packus_epi32(x0, x1);
-  u2 = highbd_clamp_epi16(x0, bd);
-
-  x0 = _mm_unpacklo_epi16(v3, zero);
-  x1 = _mm_unpackhi_epi16(v3, zero);
-  x0 = _mm_add_epi32(in[6], x0);
-  x1 = _mm_add_epi32(in[7], x1);
-  x0 = _mm_packus_epi32(x0, x1);
-  u3 = highbd_clamp_epi16(x0, bd);
-
-  x0 = _mm_unpacklo_epi16(v4, zero);
-  x1 = _mm_unpackhi_epi16(v4, zero);
-  x0 = _mm_add_epi32(in[8], x0);
-  x1 = _mm_add_epi32(in[9], x1);
-  x0 = _mm_packus_epi32(x0, x1);
-  u4 = highbd_clamp_epi16(x0, bd);
-
-  x0 = _mm_unpacklo_epi16(v5, zero);
-  x1 = _mm_unpackhi_epi16(v5, zero);
-  x0 = _mm_add_epi32(in[10], x0);
-  x1 = _mm_add_epi32(in[11], x1);
-  x0 = _mm_packus_epi32(x0, x1);
-  u5 = highbd_clamp_epi16(x0, bd);
-
-  x0 = _mm_unpacklo_epi16(v6, zero);
-  x1 = _mm_unpackhi_epi16(v6, zero);
-  x0 = _mm_add_epi32(in[12], x0);
-  x1 = _mm_add_epi32(in[13], x1);
-  x0 = _mm_packus_epi32(x0, x1);
-  u6 = highbd_clamp_epi16(x0, bd);
-
-  x0 = _mm_unpacklo_epi16(v7, zero);
-  x1 = _mm_unpackhi_epi16(v7, zero);
-  x0 = _mm_add_epi32(in[14], x0);
-  x1 = _mm_add_epi32(in[15], x1);
-  x0 = _mm_packus_epi32(x0, x1);
-  u7 = highbd_clamp_epi16(x0, bd);
+  if (flipud) {
+    u0 = get_recon_8x8(v0, in[14], in[15], fliplr, bd);
+    u1 = get_recon_8x8(v1, in[12], in[13], fliplr, bd);
+    u2 = get_recon_8x8(v2, in[10], in[11], fliplr, bd);
+    u3 = get_recon_8x8(v3, in[8], in[9], fliplr, bd);
+    u4 = get_recon_8x8(v4, in[6], in[7], fliplr, bd);
+    u5 = get_recon_8x8(v5, in[4], in[5], fliplr, bd);
+    u6 = get_recon_8x8(v6, in[2], in[3], fliplr, bd);
+    u7 = get_recon_8x8(v7, in[0], in[1], fliplr, bd);
+  } else {
+    u0 = get_recon_8x8(v0, in[0], in[1], fliplr, bd);
+    u1 = get_recon_8x8(v1, in[2], in[3], fliplr, bd);
+    u2 = get_recon_8x8(v2, in[4], in[5], fliplr, bd);
+    u3 = get_recon_8x8(v3, in[6], in[7], fliplr, bd);
+    u4 = get_recon_8x8(v4, in[8], in[9], fliplr, bd);
+    u5 = get_recon_8x8(v5, in[10], in[11], fliplr, bd);
+    u6 = get_recon_8x8(v6, in[12], in[13], fliplr, bd);
+    u7 = get_recon_8x8(v7, in[14], in[15], fliplr, bd);
+  }
 
   _mm_store_si128((__m128i *)(output + 0 * stride), u0);
   _mm_store_si128((__m128i *)(output + 1 * stride), u1);
@@ -673,7 +709,7 @@
       idct8x8_sse4_1(out, in, cfg->cos_bit_row[2]);
       transpose_8x8(in, out);
       idct8x8_sse4_1(out, in, cfg->cos_bit_col[2]);
-      write_buffer_8x8(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_8x8(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
     case DCT_ADST:
       cfg = &inv_txfm_2d_cfg_dct_adst_8;
@@ -682,7 +718,7 @@
       iadst8x8_sse4_1(out, in, cfg->cos_bit_row[2]);
       transpose_8x8(in, out);
       idct8x8_sse4_1(out, in, cfg->cos_bit_col[2]);
-      write_buffer_8x8(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_8x8(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
     case ADST_DCT:
       cfg = &inv_txfm_2d_cfg_adst_dct_8;
@@ -691,7 +727,7 @@
       idct8x8_sse4_1(out, in, cfg->cos_bit_row[2]);
       transpose_8x8(in, out);
       iadst8x8_sse4_1(out, in, cfg->cos_bit_col[2]);
-      write_buffer_8x8(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_8x8(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
     case ADST_ADST:
       cfg = &inv_txfm_2d_cfg_adst_adst_8;
@@ -700,8 +736,55 @@
       iadst8x8_sse4_1(out, in, cfg->cos_bit_row[2]);
       transpose_8x8(in, out);
       iadst8x8_sse4_1(out, in, cfg->cos_bit_col[2]);
-      write_buffer_8x8(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_8x8(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
+#if CONFIG_EXT_TX
+    case FLIPADST_DCT:
+      cfg = &inv_txfm_2d_cfg_adst_dct_8;
+      load_buffer_8x8(coeff, in);
+      transpose_8x8(in, out);
+      idct8x8_sse4_1(out, in, cfg->cos_bit_row[2]);
+      transpose_8x8(in, out);
+      iadst8x8_sse4_1(out, in, cfg->cos_bit_col[2]);
+      write_buffer_8x8(in, output, stride, 0, 1, -cfg->shift[1], bd);
+      break;
+    case DCT_FLIPADST:
+      cfg = &inv_txfm_2d_cfg_dct_adst_8;
+      load_buffer_8x8(coeff, in);
+      transpose_8x8(in, out);
+      iadst8x8_sse4_1(out, in, cfg->cos_bit_row[2]);
+      transpose_8x8(in, out);
+      idct8x8_sse4_1(out, in, cfg->cos_bit_col[2]);
+      write_buffer_8x8(in, output, stride, 1, 0, -cfg->shift[1], bd);
+      break;
+    case ADST_FLIPADST:
+      cfg = &inv_txfm_2d_cfg_adst_adst_8;
+      load_buffer_8x8(coeff, in);
+      transpose_8x8(in, out);
+      iadst8x8_sse4_1(out, in, cfg->cos_bit_row[2]);
+      transpose_8x8(in, out);
+      iadst8x8_sse4_1(out, in, cfg->cos_bit_col[2]);
+      write_buffer_8x8(in, output, stride, 1, 0, -cfg->shift[1], bd);
+      break;
+    case FLIPADST_FLIPADST:
+      cfg = &inv_txfm_2d_cfg_adst_adst_8;
+      load_buffer_8x8(coeff, in);
+      transpose_8x8(in, out);
+      iadst8x8_sse4_1(out, in, cfg->cos_bit_row[2]);
+      transpose_8x8(in, out);
+      iadst8x8_sse4_1(out, in, cfg->cos_bit_col[2]);
+      write_buffer_8x8(in, output, stride, 1, 1, -cfg->shift[1], bd);
+      break;
+    case FLIPADST_ADST:
+      cfg = &inv_txfm_2d_cfg_adst_adst_8;
+      load_buffer_8x8(coeff, in);
+      transpose_8x8(in, out);
+      iadst8x8_sse4_1(out, in, cfg->cos_bit_row[2]);
+      transpose_8x8(in, out);
+      iadst8x8_sse4_1(out, in, cfg->cos_bit_col[2]);
+      write_buffer_8x8(in, output, stride, 0, 1, -cfg->shift[1], bd);
+      break;
+#endif  // CONFIG_EXT_TX
     default:
       assert(0);
   }
@@ -725,25 +808,46 @@
   }
 }
 
+static void swap_addr(uint16_t **output1, uint16_t **output2) {
+  uint16_t *tmp;
+  tmp = *output1;
+  *output1 = *output2;
+  *output2 = tmp;
+}
+
 static void write_buffer_16x16(__m128i *in, uint16_t *output, int stride,
-                               int shift, int bd) {
+                               int fliplr, int flipud, int shift, int bd) {
   __m128i in8x8[16];
+  uint16_t *leftUp = &output[0];
+  uint16_t *rightUp = &output[8];
+  uint16_t *leftDown = &output[8 * stride];
+  uint16_t *rightDown = &output[8 * stride + 8];
+
+  if (fliplr) {
+    swap_addr(&leftUp, &rightUp);
+    swap_addr(&leftDown, &rightDown);
+  }
+
+  if (flipud) {
+    swap_addr(&leftUp, &leftDown);
+    swap_addr(&rightUp, &rightDown);
+  }
 
   // Left-up quarter
   assign_8x8_input_from_16x16(in, in8x8, 0);
-  write_buffer_8x8(in8x8, &output[0], stride, shift, bd);
+  write_buffer_8x8(in8x8, leftUp, stride, fliplr, flipud, shift, bd);
 
   // Right-up quarter
   assign_8x8_input_from_16x16(in, in8x8, 2);
-  write_buffer_8x8(in8x8, &output[8], stride, shift, bd);
+  write_buffer_8x8(in8x8, rightUp, stride, fliplr, flipud, shift, bd);
 
   // Left-down quarter
   assign_8x8_input_from_16x16(in, in8x8, 32);
-  write_buffer_8x8(in8x8, &output[8 * stride], stride, shift, bd);
+  write_buffer_8x8(in8x8, leftDown, stride, fliplr, flipud, shift, bd);
 
   // Right-down quarter
   assign_8x8_input_from_16x16(in, in8x8, 34);
-  write_buffer_8x8(in8x8, &output[8 * stride + 8], stride, shift, bd);
+  write_buffer_8x8(in8x8, rightDown, stride, fliplr, flipud, shift, bd);
 }
 
 static void idct16x16_sse4_1(__m128i *in, __m128i *out, int bit) {
@@ -1207,7 +1311,7 @@
       round_shift_16x16(in, -cfg->shift[0]);
       transpose_16x16(in, out);
       idct16x16_sse4_1(out, in, cfg->cos_bit_col[2]);
-      write_buffer_16x16(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_16x16(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
     case DCT_ADST:
       cfg = &inv_txfm_2d_cfg_dct_adst_16;
@@ -1217,7 +1321,7 @@
       round_shift_16x16(in, -cfg->shift[0]);
       transpose_16x16(in, out);
       idct16x16_sse4_1(out, in, cfg->cos_bit_col[2]);
-      write_buffer_16x16(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_16x16(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
     case ADST_DCT:
       cfg = &inv_txfm_2d_cfg_adst_dct_16;
@@ -1227,7 +1331,7 @@
       round_shift_16x16(in, -cfg->shift[0]);
       transpose_16x16(in, out);
       iadst16x16_sse4_1(out, in, cfg->cos_bit_col[2]);
-      write_buffer_16x16(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_16x16(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
     case ADST_ADST:
       cfg = &inv_txfm_2d_cfg_adst_adst_16;
@@ -1237,8 +1341,60 @@
       round_shift_16x16(in, -cfg->shift[0]);
       transpose_16x16(in, out);
       iadst16x16_sse4_1(out, in, cfg->cos_bit_col[2]);
-      write_buffer_16x16(in, output, stride, -cfg->shift[1], bd);
+      write_buffer_16x16(in, output, stride, 0, 0, -cfg->shift[1], bd);
       break;
+#if CONFIG_EXT_TX
+    case FLIPADST_DCT:
+      cfg = &inv_txfm_2d_cfg_adst_dct_16;
+      load_buffer_16x16(coeff, in);
+      transpose_16x16(in, out);
+      idct16x16_sse4_1(out, in, cfg->cos_bit_row[2]);
+      round_shift_16x16(in, -cfg->shift[0]);
+      transpose_16x16(in, out);
+      iadst16x16_sse4_1(out, in, cfg->cos_bit_col[2]);
+      write_buffer_16x16(in, output, stride, 0, 1, -cfg->shift[1], bd);
+      break;
+    case DCT_FLIPADST:
+      cfg = &inv_txfm_2d_cfg_dct_adst_16;
+      load_buffer_16x16(coeff, in);
+      transpose_16x16(in, out);
+      iadst16x16_sse4_1(out, in, cfg->cos_bit_row[2]);
+      round_shift_16x16(in, -cfg->shift[0]);
+      transpose_16x16(in, out);
+      idct16x16_sse4_1(out, in, cfg->cos_bit_col[2]);
+      write_buffer_16x16(in, output, stride, 1, 0, -cfg->shift[1], bd);
+      break;
+    case ADST_FLIPADST:
+      cfg = &inv_txfm_2d_cfg_adst_adst_16;
+      load_buffer_16x16(coeff, in);
+      transpose_16x16(in, out);
+      iadst16x16_sse4_1(out, in, cfg->cos_bit_row[2]);
+      round_shift_16x16(in, -cfg->shift[0]);
+      transpose_16x16(in, out);
+      iadst16x16_sse4_1(out, in, cfg->cos_bit_col[2]);
+      write_buffer_16x16(in, output, stride, 1, 0, -cfg->shift[1], bd);
+      break;
+    case FLIPADST_FLIPADST:
+      cfg = &inv_txfm_2d_cfg_adst_adst_16;
+      load_buffer_16x16(coeff, in);
+      transpose_16x16(in, out);
+      iadst16x16_sse4_1(out, in, cfg->cos_bit_row[2]);
+      round_shift_16x16(in, -cfg->shift[0]);
+      transpose_16x16(in, out);
+      iadst16x16_sse4_1(out, in, cfg->cos_bit_col[2]);
+      write_buffer_16x16(in, output, stride, 1, 1, -cfg->shift[1], bd);
+      break;
+    case FLIPADST_ADST:
+      cfg = &inv_txfm_2d_cfg_adst_adst_16;
+      load_buffer_16x16(coeff, in);
+      transpose_16x16(in, out);
+      iadst16x16_sse4_1(out, in, cfg->cos_bit_row[2]);
+      round_shift_16x16(in, -cfg->shift[0]);
+      transpose_16x16(in, out);
+      iadst16x16_sse4_1(out, in, cfg->cos_bit_col[2]);
+      write_buffer_16x16(in, output, stride, 0, 1, -cfg->shift[1], bd);
+      break;
+#endif  // CONFIG_EXT_TX
     default:
       assert(0);
   }