daala_tx: Add SIMD version of 16-point DST/FlipDST

Change-Id: I626a5995e8adee34b2c4c979ed0b79fa8704e683
diff --git a/av1/common/x86/daala_inv_txfm_avx2.c b/av1/common/x86/daala_inv_txfm_avx2.c
index 7903b82..3771a80 100644
--- a/av1/common/x86/daala_inv_txfm_avx2.c
+++ b/av1/common/x86/daala_inv_txfm_avx2.c
@@ -1391,6 +1391,38 @@
                            od_idct16_kernel8_epi16, od_idct16_kernel16_epi16);
 }
 
+static void od_row_idst16_avx2(int16_t *out, int rows, const tran_low_t *in) {
+  od_row_tx16_avx2(out, rows, in,
+#if CONFIG_RECT_TX_EXT
+                   od_idst16_kernel8_epi16,
+#endif
+                   od_idst16_kernel8_epi32);
+}
+
+static void od_col_idst16_add_hbd_avx2(unsigned char *output_pixels,
+                                       int output_stride, int cols,
+                                       const int16_t *in, int bd) {
+  od_col_tx16_add_hbd_avx2(output_pixels, output_stride, cols, in, bd,
+                           od_idst16_kernel8_epi16, od_idst16_kernel16_epi16);
+}
+
+static void od_row_flip_idst16_avx2(int16_t *out, int rows,
+                                    const tran_low_t *in) {
+  od_row_tx16_avx2(out, rows, in,
+#if CONFIG_RECT_TX_EXT
+                   od_flip_idst16_kernel8_epi16,
+#endif
+                   od_flip_idst16_kernel8_epi32);
+}
+
+static void od_col_flip_idst16_add_hbd_avx2(unsigned char *output_pixels,
+                                            int output_stride, int cols,
+                                            const int16_t *in, int bd) {
+  od_col_tx16_add_hbd_avx2(output_pixels, output_stride, cols, in, bd,
+                           od_flip_idst16_kernel8_epi16,
+                           od_flip_idst16_kernel16_epi16);
+}
+
 typedef void (*daala_row_itx)(int16_t *out, int rows, const tran_low_t *in);
 typedef void (*daala_col_itx_add)(unsigned char *output_pixels,
                                   int output_stride, int cols,
@@ -1404,7 +1436,7 @@
   { od_row_idct8_avx2, od_row_idst8_avx2, od_row_flip_idst8_avx2,
     od_row_iidtx8_avx2 },
   // 16-point transforms
-  { od_row_idct16_avx2, NULL, NULL, NULL },
+  { od_row_idct16_avx2, od_row_idst16_avx2, od_row_flip_idst16_avx2, NULL },
   // 32-point transforms
   { NULL, NULL, NULL, NULL },
 #if CONFIG_TX64X64
@@ -1438,7 +1470,8 @@
       { od_col_idct8_add_hbd_avx2, od_col_idst8_add_hbd_avx2,
         od_col_flip_idst8_add_hbd_avx2, od_col_iidtx8_add_hbd_avx2 },
       // 16-point transforms
-      { od_col_idct16_add_hbd_avx2, NULL, NULL, NULL },
+      { od_col_idct16_add_hbd_avx2, od_col_idst16_add_hbd_avx2,
+        od_col_flip_idst16_add_hbd_avx2, NULL },
       // 32-point transforms
       { NULL, NULL, NULL, NULL },
 #if CONFIG_TX64X64
diff --git a/av1/common/x86/daala_tx_kernels.h b/av1/common/x86/daala_tx_kernels.h
index f7b51fc..19f620f 100644
--- a/av1/common/x86/daala_tx_kernels.h
+++ b/av1/common/x86/daala_tx_kernels.h
@@ -36,11 +36,14 @@
 
 static INLINE void OD_KERNEL_FUNC(od_rotate_add)(OD_REG *q0, OD_REG *q1, int c0,
                                                  int r0, int c1, int r1, int c2,
-                                                 int r2, int s) {
+                                                 int r2, int s, int avg) {
   OD_REG t_;
   OD_REG u_;
 
-  t_ = OD_ADD(*q0, *q1);
+  if (avg)
+    t_ = OD_AVG(*q0, *q1);
+  else
+    t_ = OD_ADD(*q0, *q1);
   u_ = OD_MUL(*q1, c0, r0);
   *q1 = OD_MUL(*q0, c1, r1);
   t_ = OD_MUL(t_, c2, r2);
@@ -86,6 +89,24 @@
   *q1 = OD_ADD(*q1, t_);
 }
 
+static INLINE void OD_KERNEL_FUNC(od_rotate_sub2)(OD_REG *q0, OD_REG *q1,
+                                                  int c0, int r0, int c1,
+                                                  int r1, int c2, int r2,
+                                                  int avg) {
+  OD_REG t_;
+  OD_REG u_;
+
+  if (avg)
+    t_ = OD_HRSUB(*q1, *q0);
+  else
+    t_ = OD_SUB(*q1, *q0);
+  u_ = OD_MUL(*q1, c0, r0);
+  *q1 = OD_MUL(*q0, c1, r1);
+  t_ = OD_MUL(t_, c2, r2);
+  *q0 = OD_SUB(t_, u_);
+  *q1 = OD_SUB(*q1, t_);
+}
+
 static INLINE void OD_KERNEL_FUNC(od_rotate_subh)(OD_REG *q0, OD_REG *q1,
                                                   OD_REG *q1h, int c0, int r0,
                                                   int c1, int r1, int c2,
@@ -126,6 +147,16 @@
   *q1 = OD_SUB(*q0, *q1);
 }
 
+static INLINE void OD_KERNEL_FUNC(od_butterfly_add2)(OD_REG *q0, OD_REG *q1) {
+  *q0 = OD_ADD(*q0, OD_RSHIFT1(*q1));
+  *q1 = OD_SUB(*q1, *q0);
+}
+
+static INLINE void OD_KERNEL_FUNC(od_butterfly_sub2)(OD_REG *q0, OD_REG *q1) {
+  *q0 = OD_SUB(*q0, OD_RSHIFT1(*q1));
+  *q1 = OD_ADD(*q1, *q0);
+}
+
 static INLINE void OD_KERNEL_FUNC(od_butterfly_addh)(OD_REG *q0, OD_REG *q1,
                                                      OD_REG *q1h) {
   *q0 = OD_ADD(*q0, *q1h);
@@ -349,7 +380,7 @@
   /* 17911/16384 ~= Sin[15*Pi/32] + Cos[15*Pi/32] ~= 1.0932018670017576 */
   /* 14699/16384 ~= Sin[15*Pi/32] - Cos[15*Pi/32] ~= 0.8971675863426363 */
   /* 803/8192 ~= Cos[15*Pi/32] ~= 0.0980171403295606 */
-  OD_KERNEL_FUNC(od_rotate_add)(r7, r0, 17911, 14, 14699, 14, 803, 13, 0);
+  OD_KERNEL_FUNC(od_rotate_add)(r7, r0, 17911, 14, 14699, 14, 803, 13, 0, 0);
   /* 40869/32768 ~= Sin[13*Pi/32] + Cos[13*Pi/32] ~= 1.247225012986671 */
   /* 21845/32768 ~= Sin[13*Pi/32] - Cos[13*Pi/32] ~= 0.6666556584777465 */
   /* 1189/4096 ~= Cos[13*Pi/32] ~= 0.29028467725446233 */
@@ -357,7 +388,7 @@
   /* 22173/16384 ~= Sin[11*Pi/32] + Cos[11*Pi/32] ~= 1.3533180011743526 */
   /* 3363/8192 ~= Sin[11*Pi/32] - Cos[11*Pi/32] ~= 0.4105245275223574 */
   /* 15447/32768 ~= Cos[11*Pi/32] ~= 0.47139673682599764 */
-  OD_KERNEL_FUNC(od_rotate_add)(r5, r2, 22173, 14, 3363, 13, 15447, 15, 0);
+  OD_KERNEL_FUNC(od_rotate_add)(r5, r2, 22173, 14, 3363, 13, 15447, 15, 0, 0);
   /* 23059/16384 ~= Sin[9*Pi/32] + Cos[9*Pi/32] ~= 1.4074037375263826 */
   /* 2271/16384 ~= Sin[9*Pi/32] - Cos[9*Pi/32] ~= 0.1386171691990915 */
   /* 5197/8192 ~= Cos[9*Pi/32] ~= 0.6343932841636455 */
@@ -386,7 +417,7 @@
   /* 12665/16384 ~= (Sin[15*Pi/32] + Cos[15*Pi/32])/Sqrt[2] ~= 0.77301045336 */
   /* 5197/4096 ~= (Sin[15*Pi/32] - Cos[15*Pi/32])*Sqrt[2] ~= 1.2687865683273 */
   /* 2271/16384 ~= Cos[15*Pi/32]*Sqrt[2] ~= 0.13861716919909148 */
-  OD_KERNEL_FUNC(od_rotate_add)(r7, r0, 12665, 14, 5197, 12, 2271, 14, 1);
+  OD_KERNEL_FUNC(od_rotate_add)(r7, r0, 12665, 14, 5197, 12, 2271, 14, 1, 0);
   /* 28899/32768 ~= (Sin[13*Pi/32] + Cos[13*Pi/32])/Sqrt[2] ~= 0.88192126435 */
   /* 30893/32768 ~= (Sin[13*Pi/32] - Cos[13*Pi/32])*Sqrt[2] ~= 0.94279347365 */
   /* 3363/8192 ~= Cos[13*Pi/32]*Sqrt[2] ~= 0.41052452752235735 */
@@ -394,7 +425,7 @@
   /* 31357/32768 ~= (Sin[11*Pi/32] + Cos[11*Pi/32])/Sqrt[2] ~= 0.95694033573 */
   /* 1189/2048 ~= (Sin[11*Pi/32] - Cos[11*Pi/32])*Sqrt[2] ~= 0.5805693545089 */
   /* 21845/32768 ~= Cos[11*Pi/32] ~= 0.6666556584777465 */
-  OD_KERNEL_FUNC(od_rotate_add)(r5, r2, 31357, 15, 1189, 11, 21845, 15, 1);
+  OD_KERNEL_FUNC(od_rotate_add)(r5, r2, 31357, 15, 1189, 11, 21845, 15, 1, 0);
   /* 16305/16384 ~= (Sin[9*Pi/32] + Cos[9*Pi/32])/Sqrt[2] ~= 0.9951847266722 */
   /* 803/4096 ~= (Sin[9*Pi/32] - Cos[9*Pi/32])*Sqrt[2] ~= 0.1960342806591213 */
   /* 14699/16384 ~= Cos[9*Pi/32]*Sqrt[2] ~= 0.8971675863426364 */
@@ -434,3 +465,127 @@
   OD_KERNEL_FUNC(od_butterfly_addh)(se, s1, &s1h);
   OD_KERNEL_FUNC(od_butterfly_add)(s0, sf);
 }
+
+static INLINE void OD_KERNEL_FUNC(od_idst16)(OD_REG *s0, OD_REG *s1, OD_REG *s2,
+                                             OD_REG *s3, OD_REG *s4, OD_REG *s5,
+                                             OD_REG *s6, OD_REG *s7, OD_REG *s8,
+                                             OD_REG *s9, OD_REG *sa, OD_REG *sb,
+                                             OD_REG *sc, OD_REG *sd, OD_REG *se,
+                                             OD_REG *sf) {
+  OD_REG s0h;
+  OD_REG s1h;
+  OD_REG s2h;
+  OD_REG s3h;
+  OD_REG s4h;
+  OD_REG s5h;
+  OD_REG s6h;
+  OD_REG s7h;
+  OD_REG sbh;
+  OD_REG sfh;
+  OD_REG h;
+  OD_KERNEL_FUNC(od_rotate45)(s9, s6, 1);
+  OD_KERNEL_FUNC(od_rotate45)(sa, s5, 1);
+  OD_KERNEL_FUNC(od_rotate45)(s8, s7, 1);
+  OD_KERNEL_FUNC(od_idst2)(s3, sc, 0);
+  OD_KERNEL_FUNC(od_idst2)(sb, s4, 1);
+  OD_KERNEL_FUNC(od_butterfly_v3)(s2, sa, &h);
+  OD_KERNEL_FUNC(od_butterfly_v2)(sd, s5, &h);
+  OD_KERNEL_FUNC(od_butterfly_v2)(s9, s1, &h);
+  OD_KERNEL_FUNC(od_butterfly_v3)(s6, se, &h);
+  OD_KERNEL_FUNC(od_butterfly_v3)(sc, sb, &sbh);
+  OD_KERNEL_FUNC(od_butterfly_v3)(s7, sf, &sfh);
+  OD_KERNEL_FUNC(od_butterfly_v2)(s8, s0, &s0h);
+  OD_KERNEL_FUNC(od_butterfly_v3)(s3, s4, &s4h);
+  /* 38531/32768 ~= Sin[7*Pi/16] + Cos[7*Pi/16] = 1.1758756024193586 */
+  /* 12873/16384 ~= Sin[7*Pi/16] - Cos[7*Pi/16] = 0.7856949583871022 */
+  /* 6393/32768 ~= Cos[7*Pi/16] = 0.19509032201612825 */
+  OD_KERNEL_FUNC(od_rotate_sub2)
+  (s2, sd, 38531, 15, 12873, 14, 6393, 15, 0);
+  /* 22725/16384 ~= Sin[5*Pi/16] + Cos[5*Pi/16] ~= 1.3870398453221475 */
+  /* 9041/32768 ~= Sin[5*Pi/16] - Cos[5*Pi/16] ~= 0.27589937928294306 */
+  /* 18205/16384 ~= 2*Cos[5*Pi/16] ~= 1.1111404660392044 */
+  OD_KERNEL_FUNC(od_rotate_sub2)
+  (sa, s5, 22725, 14, 9041, 15, 18205, 14, 1);
+  /* 45451/32768 ~= Sin[5*Pi/16] + Cos[5*Pi/16] ~= 1.3870398453221475 */
+  /* 9041/32768 ~= Sin[5*Pi/16] - Cos[5*Pi/16] ~= 0.27589937928294306 */
+  /* 18205/32768 ~= Cos[5*Pi/16] ~= 0.5555702330196022 */
+  OD_KERNEL_FUNC(od_rotate_add)
+  (s6, s9, 45451, 15, 9041, 15, 18205, 15, 0, 0);
+  /* 9633/8192 ~= Sin[7*Pi/16] + Cos[7*Pi/16] ~= 1.1758756024193586 */
+  /* 12873/16384 ~= Sin[7*Pi/16] - Cos[7*Pi/16] ~= 0.7856949583871022 */
+  /* 12785/32768 ~= 2*Cos[7*Pi/16] ~= 0.3901806440322565 */
+  OD_KERNEL_FUNC(od_rotate_add)
+  (se, s1, 9633, 13, 12873, 14, 12785, 15, 0, 1);
+  OD_KERNEL_FUNC(od_butterfly_subh)(s8, s4, &s4h);
+  OD_KERNEL_FUNC(od_butterfly_addh)(s7, sb, &sbh);
+  OD_KERNEL_FUNC(od_butterfly_subh)(s3, sf, &sfh);
+  OD_KERNEL_FUNC(od_butterfly_addh)(sc, s0, &s0h);
+  OD_KERNEL_FUNC(od_butterfly_add2)(sd, se);
+  OD_KERNEL_FUNC(od_butterfly_add2)(s2, s1);
+  OD_KERNEL_FUNC(od_butterfly_sub2)(s6, s5);
+  OD_KERNEL_FUNC(od_butterfly_sub2)(s9, sa);
+  OD_KERNEL_FUNC(od_butterfly_v2)(se, s0, &s0h);
+  OD_KERNEL_FUNC(od_butterfly_v2)(sf, s1, &s1h);
+  OD_KERNEL_FUNC(od_butterfly_v3)(sc, s2, &s2h);
+  OD_KERNEL_FUNC(od_butterfly_v3)(sd, s3, &s3h);
+  OD_KERNEL_FUNC(od_butterfly_v2)(sa, s4, &s4h);
+  OD_KERNEL_FUNC(od_butterfly_v2)(sb, s5, &s5h);
+  OD_KERNEL_FUNC(od_butterfly_v3)(s8, s6, &s6h);
+  OD_KERNEL_FUNC(od_butterfly_v3)(s9, s7, &s7h);
+  /* 32729/32768 ~= (Sin[17*Pi/64] + Cos[17*Pi/64])/Sqrt[2] ~= 0.99879545620 */
+  /* 201/2048 ~= (Sin[17*Pi/64] - Cos[17*Pi/64])*Sqrt[2] ~= 0.09813534865484 */
+  /* 31121/32768 ~= Cos[17*Pi/64]*Sqrt[2] ~= 0.9497277818777543 */
+  OD_KERNEL_FUNC(od_rotate_subh)
+  (se, s1, &s1h, 32729, 15, 201, 11, 31121, 15, 0);
+  /* 32413/32768 ~= (Sin[19*Pi/64] + Cos[19*Pi/64])/Sqrt[2] ~= 0.98917650996 */
+  /* 601/2048 ~= (Sin[19*Pi/64] - Cos[19*Pi/64])*Sqrt[2] ~= 0.29346094891072 */
+  /* 27605/32768 ~= Cos[19*Pi/64]*Sqrt[2] ~= 0.8424460355094193 */
+  OD_KERNEL_FUNC(od_rotate_addh)
+  (s9, s6, &s6h, 32413, 15, 601, 11, 27605, 15, 0);
+  /* 15893/16384 ~= (Sin[21*Pi/64] + Cos[21*Pi/64])/Sqrt[2] ~= 0.97003125319 */
+  /* 3981/8192 ~= (Sin[21*Pi/64] - Cos[21*Pi/64])*Sqrt[2] ~= 0.4859603598065 */
+  /* 1489/2048 ~= Cos[21*Pi/64]*Sqrt[2] ~= 0.72705107329128 */
+  OD_KERNEL_FUNC(od_rotate_subh)
+  (sa, s5, &s5h, 15893, 14, 3981, 13, 1489, 11, 0);
+  /* 30853/32768 ~= (Sin[23*Pi/64] + Cos[23*Pi/64])/Sqrt[2] ~= 0.94154406518 */
+  /* 11039/16384 ~= (Sin[23*Pi/64] - Cos[23*Pi/64])*Sqrt[2] ~= 0.67377970678 */
+  /* 19813/32768 ~= Cos[23*Pi/64]*Sqrt[2] ~= 0.6046542117908008 */
+  OD_KERNEL_FUNC(od_rotate_addh)
+  (sd, s2, &s2h, 30853, 15, 11039, 14, 19813, 15, 0);
+  /* 14811/16384 ~= (Sin[25*Pi/64] + Cos[25*Pi/64])/Sqrt[2] ~= 0.90398929312 */
+  /* 7005/8192 ~= (Sin[25*Pi/64] - Cos[25*Pi/64])*Sqrt[2] ~= 0.8551101868606 */
+  /* 3903/8192 ~= Cos[25*Pi/64]*Sqrt[2] ~= 0.47643419969316125 */
+  OD_KERNEL_FUNC(od_rotate_subh)
+  (sc, s3, &s3h, 14811, 14, 7005, 13, 3903, 13, 0);
+  /* 14053/16384 ~= (Sin[27*Pi/64] + Cos[27*Pi/64])/Sqrt[2] ~= 0.85772861000 */
+  /* 8423/8192 ~= (Sin[27*Pi/64] - Cos[27*Pi/64])*Sqrt[2] ~= 1.0282054883864 */
+  /* 2815/8192 ~= Cos[27*Pi/64]*Sqrt[2] ~= 0.34362586580705035 */
+  OD_KERNEL_FUNC(od_rotate_addh)
+  (sb, s4, &s4h, 14053, 14, 8423, 13, 2815, 13, 0);
+  /* 1645/2048 ~= (Sin[29*Pi/64] + Cos[29*Pi/64])/Sqrt[2] ~= 0.8032075314806 */
+  /* 305/256 ~= (Sin[29*Pi/64] - Cos[29*Pi/64])*Sqrt[2] ~= 1.191398608984867 */
+  /* 425/2048 ~= Cos[29*Pi/64]*Sqrt[2] ~= 0.20750822698821159 */
+  OD_KERNEL_FUNC(od_rotate_subh)
+  (s8, s7, &s7h, 1645, 11, 305, 8, 425, 11, 0);
+  /* 24279/32768 ~= (Sin[31*Pi/64] + Cos[31*Pi/64])/Sqrt[2] ~= 0.74095112535 */
+  /* 44011/32768 ~= (Sin[31*Pi/64] - Cos[31*Pi/64])*Sqrt[2] ~= 1.34311790969 */
+  /* 1137/16384 ~= Cos[31*Pi/64]*Sqrt[2] ~= 0.06939217050794069 */
+  OD_KERNEL_FUNC(od_rotate_addh)
+  (sf, s0, &s0h, 24279, 15, 44011, 15, 1137, 14, 0);
+}
+
+static INLINE void OD_KERNEL_FUNC(od_flip_idst16)(
+    OD_REG *s0, OD_REG *s1, OD_REG *s2, OD_REG *s3, OD_REG *s4, OD_REG *s5,
+    OD_REG *s6, OD_REG *s7, OD_REG *s8, OD_REG *s9, OD_REG *sa, OD_REG *sb,
+    OD_REG *sc, OD_REG *sd, OD_REG *se, OD_REG *sf) {
+  OD_KERNEL_FUNC(od_idst16)
+  (s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sa, sb, sc, sd, se, sf);
+  OD_SWAP(s0, sf);
+  OD_SWAP(s1, se);
+  OD_SWAP(s2, sd);
+  OD_SWAP(s3, sc);
+  OD_SWAP(s4, sb);
+  OD_SWAP(s5, sa);
+  OD_SWAP(s6, s9);
+  OD_SWAP(s7, s8);
+}