Let rectangular transforms do the shorter side first

Change-Id: I41a78f5066b96f59ab8f587bc5b1955f6806b103
diff --git a/av1/common/av1_fwd_txfm2d.c b/av1/common/av1_fwd_txfm2d.c
index c124f3a..1fced45 100644
--- a/av1/common/av1_fwd_txfm2d.c
+++ b/av1/common/av1_fwd_txfm2d.c
@@ -109,10 +109,27 @@
 
 void av1_fwd_txfm2d_4x8_c(const int16_t *input, int32_t *output, int stride,
                           int tx_type, int bd) {
+#if CONFIG_TXMG  // transpose in, run the rotated transform, transpose out
+  (void)bd;
+  int32_t txfm_buf[4 * 8];  // transform result, still in rotated layout
+  int16_t rinput[4 * 8];  // transposed copy of input, stride rw
+  int tx_size = TX_4X8;
+  int rtx_size = av1_rotate_tx_size(tx_size);
+  int rtx_type = av1_rotate_tx_type(tx_type);
+  int w = tx_size_wide[tx_size];
+  int h = tx_size_high[tx_size];
+  int rw = h;  // rotated block dimensions: width and height swapped
+  int rh = w;
+  transpose_int16(rinput, rw, input, stride, w, h);
+  TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(rtx_type, rtx_size);
+  fwd_txfm2d_c(rinput, txfm_buf, rw, &cfg, output);  // bufs swapped vs #else
+  transpose_int32(output, w, txfm_buf, rw, rw, rh);  // back to w x h layout
+#else  // direct path: transform the block as given
   int32_t txfm_buf[4 * 8];
   TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(tx_type, TX_4X8);
   (void)bd;
   fwd_txfm2d_c(input, output, stride, &cfg, txfm_buf);
+#endif  // CONFIG_TXMG
 }
 
 void av1_fwd_txfm2d_8x4_c(const int16_t *input, int32_t *output, int stride,
@@ -125,10 +142,27 @@
 
 void av1_fwd_txfm2d_8x16_c(const int16_t *input, int32_t *output, int stride,
                            int tx_type, int bd) {
+#if CONFIG_TXMG  // transpose in, run the rotated transform, transpose out
+  (void)bd;
+  int32_t txfm_buf[8 * 16];  // transform result, still in rotated layout
+  int16_t rinput[8 * 16];  // transposed copy of input, stride rw
+  int tx_size = TX_8X16;
+  int rtx_size = av1_rotate_tx_size(tx_size);
+  int rtx_type = av1_rotate_tx_type(tx_type);
+  int w = tx_size_wide[tx_size];
+  int h = tx_size_high[tx_size];
+  int rw = h;  // rotated block dimensions: width and height swapped
+  int rh = w;
+  transpose_int16(rinput, rw, input, stride, w, h);
+  TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(rtx_type, rtx_size);
+  fwd_txfm2d_c(rinput, txfm_buf, rw, &cfg, output);  // bufs swapped vs #else
+  transpose_int32(output, w, txfm_buf, rw, rw, rh);  // back to w x h layout
+#else  // direct path: transform the block as given
   int32_t txfm_buf[8 * 16];
   TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(tx_type, TX_8X16);
   (void)bd;
   fwd_txfm2d_c(input, output, stride, &cfg, txfm_buf);
+#endif  // CONFIG_TXMG
 }
 
 void av1_fwd_txfm2d_16x8_c(const int16_t *input, int32_t *output, int stride,
@@ -141,10 +175,27 @@
 
 void av1_fwd_txfm2d_16x32_c(const int16_t *input, int32_t *output, int stride,
                             int tx_type, int bd) {
+#if CONFIG_TXMG  // transpose in, run the rotated transform, transpose out
+  (void)bd;
+  int32_t txfm_buf[16 * 32];  // transform result, still in rotated layout
+  int16_t rinput[16 * 32];  // transposed copy of input, stride rw
+  int tx_size = TX_16X32;
+  int rtx_size = av1_rotate_tx_size(tx_size);
+  int rtx_type = av1_rotate_tx_type(tx_type);
+  int w = tx_size_wide[tx_size];
+  int h = tx_size_high[tx_size];
+  int rw = h;  // rotated block dimensions: width and height swapped
+  int rh = w;
+  transpose_int16(rinput, rw, input, stride, w, h);
+  TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(rtx_type, rtx_size);
+  fwd_txfm2d_c(rinput, txfm_buf, rw, &cfg, output);  // bufs swapped vs #else
+  transpose_int32(output, w, txfm_buf, rw, rw, rh);  // back to w x h layout
+#else  // direct path: transform the block as given
   int32_t txfm_buf[16 * 32];
   TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(tx_type, TX_16X32);
   (void)bd;
   fwd_txfm2d_c(input, output, stride, &cfg, txfm_buf);
+#endif  // CONFIG_TXMG
 }
 
 void av1_fwd_txfm2d_32x16_c(const int16_t *input, int32_t *output, int stride,
diff --git a/av1/common/av1_inv_txfm2d.c b/av1/common/av1_inv_txfm2d.c
index 58845f1..dfda0e8 100644
--- a/av1/common/av1_inv_txfm2d.c
+++ b/av1/common/av1_inv_txfm2d.c
@@ -211,8 +211,25 @@
 
 void av1_inv_txfm2d_add_8x4_c(const int32_t *input, uint16_t *output,
                               int stride, int tx_type, int bd) {
+#if CONFIG_TXMG  // transpose in, run the rotated transform, transpose out
+  int txfm_buf[8 * 4 + 8 + 8];  // padded 8 + 8 here vs 4 + 4 in the #else path
+  int32_t rinput[8 * 4];  // transposed copy of input, stride rw
+  uint16_t routput[8 * 4];  // rotated output block, stride rw
+  int tx_size = TX_8X4;
+  int rtx_size = av1_rotate_tx_size(tx_size);
+  int rtx_type = av1_rotate_tx_type(tx_type);
+  int w = tx_size_wide[tx_size];
+  int h = tx_size_high[tx_size];
+  int rw = h;  // rotated block dimensions: width and height swapped
+  int rh = w;
+  transpose_int32(rinput, rw, input, w, w, h);  // input is read with stride w
+  transpose_uint16(routput, rw, output, stride, w, h);  // seed with output^T
+  inv_txfm2d_add_facade(rinput, routput, rw, txfm_buf, rtx_type, rtx_size, bd);
+  transpose_uint16(output, stride, routput, rw, rw, rh);  // write result back
+#else  // direct path: transform the block as given
   int txfm_buf[8 * 4 + 4 + 4];
   inv_txfm2d_add_facade(input, output, stride, txfm_buf, tx_type, TX_8X4, bd);
+#endif  // CONFIG_TXMG
 }
 
 void av1_inv_txfm2d_add_8x16_c(const int32_t *input, uint16_t *output,
@@ -223,8 +240,25 @@
 
 void av1_inv_txfm2d_add_16x8_c(const int32_t *input, uint16_t *output,
                                int stride, int tx_type, int bd) {
+#if CONFIG_TXMG  // transpose in, run the rotated transform, transpose out
+  int txfm_buf[16 * 8 + 16 + 16];  // padded 16 + 16 here vs 8 + 8 in #else
+  int32_t rinput[16 * 8];  // transposed copy of input, stride rw
+  uint16_t routput[16 * 8];  // rotated output block, stride rw
+  int tx_size = TX_16X8;
+  int rtx_size = av1_rotate_tx_size(tx_size);
+  int rtx_type = av1_rotate_tx_type(tx_type);
+  int w = tx_size_wide[tx_size];
+  int h = tx_size_high[tx_size];
+  int rw = h;  // rotated block dimensions: width and height swapped
+  int rh = w;
+  transpose_int32(rinput, rw, input, w, w, h);  // input is read with stride w
+  transpose_uint16(routput, rw, output, stride, w, h);  // seed with output^T
+  inv_txfm2d_add_facade(rinput, routput, rw, txfm_buf, rtx_type, rtx_size, bd);
+  transpose_uint16(output, stride, routput, rw, rw, rh);  // write result back
+#else  // direct path: transform the block as given
   int txfm_buf[16 * 8 + 8 + 8];
   inv_txfm2d_add_facade(input, output, stride, txfm_buf, tx_type, TX_16X8, bd);
+#endif  // CONFIG_TXMG
 }
 
 void av1_inv_txfm2d_add_16x32_c(const int32_t *input, uint16_t *output,
@@ -235,8 +269,25 @@
 
 void av1_inv_txfm2d_add_32x16_c(const int32_t *input, uint16_t *output,
                                 int stride, int tx_type, int bd) {
+#if CONFIG_TXMG  // transpose in, run the rotated transform, transpose out
+  int txfm_buf[32 * 16 + 32 + 32];  // padded 32 + 32 here vs 16 + 16 in #else
+  int32_t rinput[32 * 16];  // transposed copy of input, stride rw
+  uint16_t routput[32 * 16];  // rotated output block, stride rw
+  int tx_size = TX_32X16;
+  int rtx_size = av1_rotate_tx_size(tx_size);
+  int rtx_type = av1_rotate_tx_type(tx_type);
+  int w = tx_size_wide[tx_size];
+  int h = tx_size_high[tx_size];
+  int rw = h;  // rotated block dimensions: width and height swapped
+  int rh = w;
+  transpose_int32(rinput, rw, input, w, w, h);  // input is read with stride w
+  transpose_uint16(routput, rw, output, stride, w, h);  // seed with output^T
+  inv_txfm2d_add_facade(rinput, routput, rw, txfm_buf, rtx_type, rtx_size, bd);
+  transpose_uint16(output, stride, routput, rw, rw, rh);  // write result back
+#else  // direct path: transform the block as given
   int txfm_buf[32 * 16 + 16 + 16];
   inv_txfm2d_add_facade(input, output, stride, txfm_buf, tx_type, TX_32X16, bd);
+#endif  // CONFIG_TXMG
 }
 
 void av1_inv_txfm2d_add_4x4_c(const int32_t *input, uint16_t *output,
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h
index aa4a76a..2219cf9 100644
--- a/av1/common/av1_txfm.h
+++ b/av1/common/av1_txfm.h
@@ -17,6 +17,7 @@
 #include <stdio.h>
 
 #include "av1/common/enums.h"
+#include "av1/common/blockd.h"
 #include "aom/aom_integer.h"
 #include "aom_dsp/aom_dsp_common.h"
 
@@ -209,6 +210,61 @@
   }
 }
 
+#if CONFIG_TXMG
+static INLINE int av1_rotate_tx_size(int tx_size) {  // TX_AxB -> TX_BxA
+  switch (tx_size) {
+#if CONFIG_CHROMA_2X2
+    case TX_2X2: return TX_2X2;
+#endif
+    case TX_4X4: return TX_4X4;  // square sizes map to themselves
+    case TX_8X8: return TX_8X8;
+    case TX_16X16: return TX_16X16;
+    case TX_32X32: return TX_32X32;
+#if CONFIG_TX64X64
+    case TX_64X64: return TX_64X64;
+#endif
+    case TX_4X8: return TX_8X4;  // rectangular sizes swap in pairs
+    case TX_8X4: return TX_4X8;
+    case TX_8X16: return TX_16X8;
+    case TX_16X8: return TX_8X16;
+    case TX_16X32: return TX_32X16;
+    case TX_32X16: return TX_16X32;
+    case TX_4X16: return TX_16X4;
+    case TX_16X4: return TX_4X16;
+    case TX_8X32: return TX_32X8;
+    case TX_32X8: return TX_8X32;
+    default: assert(0); return TX_INVALID;  // unlisted size: assert fires
+  }
+}
+
+static INLINE int av1_rotate_tx_type(int tx_type) {  // A_B -> B_A, V_x <-> H_x
+  switch (tx_type) {
+    case DCT_DCT: return DCT_DCT;  // symmetric names map to themselves
+    case ADST_DCT: return DCT_ADST;
+    case DCT_ADST: return ADST_DCT;
+    case ADST_ADST: return ADST_ADST;
+#if CONFIG_EXT_TX
+    case FLIPADST_DCT: return DCT_FLIPADST;
+    case DCT_FLIPADST: return FLIPADST_DCT;
+    case FLIPADST_FLIPADST: return FLIPADST_FLIPADST;
+    case ADST_FLIPADST: return FLIPADST_ADST;
+    case FLIPADST_ADST: return ADST_FLIPADST;
+    case IDTX: return IDTX;
+    case V_DCT: return H_DCT;  // V_ and H_ variants trade places
+    case H_DCT: return V_DCT;
+    case V_ADST: return H_ADST;
+    case H_ADST: return V_ADST;
+    case V_FLIPADST: return H_FLIPADST;
+    case H_FLIPADST: return V_FLIPADST;
+#endif  // CONFIG_EXT_TX
+#if CONFIG_MRC_TX
+    case MRC_DCT: return MRC_DCT;
+#endif  // CONFIG_MRC_TX
+    default: assert(0); return TX_TYPES;  // unlisted type: assert fires
+  }
+}
+#endif  // CONFIG_TXMG
+
 #if CONFIG_MRC_TX
 static INLINE int get_mrc_mask(const uint8_t *pred, int pred_stride, int *mask,
                                int mask_stride, int width, int height) {
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 4f2b6ce..060f820 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -1743,6 +1743,38 @@
   return (plane == 0) ? PLANE_TYPE_Y : PLANE_TYPE_UV;
 }
 
+static INLINE void transpose_uint8(uint8_t *dst, int dst_stride,
+                                   const uint8_t *src, int src_stride, int w,
+                                   int h) {  // src is w columns by h rows
+  int r, c;
+  for (r = 0; r < h; ++r)  // dst row c, column r <- src row r, column c
+    for (c = 0; c < w; ++c) dst[c * dst_stride + r] = src[r * src_stride + c];
+}
+
+static INLINE void transpose_uint16(uint16_t *dst, int dst_stride,
+                                    const uint16_t *src, int src_stride, int w,
+                                    int h) {  // src is w columns by h rows
+  int r, c;
+  for (r = 0; r < h; ++r)  // dst row c, column r <- src row r, column c
+    for (c = 0; c < w; ++c) dst[c * dst_stride + r] = src[r * src_stride + c];
+}
+
+static INLINE void transpose_int16(int16_t *dst, int dst_stride,
+                                   const int16_t *src, int src_stride, int w,
+                                   int h) {  // src is w columns by h rows
+  int r, c;
+  for (r = 0; r < h; ++r)  // dst row c, column r <- src row r, column c
+    for (c = 0; c < w; ++c) dst[c * dst_stride + r] = src[r * src_stride + c];
+}
+
+static INLINE void transpose_int32(int32_t *dst, int dst_stride,
+                                   const int32_t *src, int src_stride, int w,
+                                   int h) {  // src is w columns by h rows
+  int r, c;
+  for (r = 0; r < h; ++r)  // dst row c, column r <- src row r, column c
+    for (c = 0; c < w; ++c) dst[c * dst_stride + r] = src[r * src_stride + c];
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/common/convolve.c b/av1/common/convolve.c
index 54ad755..fe31408 100644
--- a/av1/common/convolve.c
+++ b/av1/common/convolve.c
@@ -14,6 +14,7 @@
 
 #include "./aom_dsp_rtcd.h"
 #include "./av1_rtcd.h"
+#include "av1/common/blockd.h"
 #include "av1/common/convolve.h"
 #include "av1/common/filter.h"
 #include "av1/common/onyxc_int.h"
@@ -422,24 +423,6 @@
 }
 #endif
 
-static INLINE void transpose_uint8(uint8_t *dst, int dst_stride,
-                                   const uint8_t *src, int src_stride, int w,
-                                   int h) {
-  int r, c;
-  for (r = 0; r < h; ++r)
-    for (c = 0; c < w; ++c)
-      dst[c * (dst_stride) + r] = src[r * (src_stride) + c];
-}
-
-static INLINE void transpose_int32(int32_t *dst, int dst_stride,
-                                   const int32_t *src, int src_stride, int w,
-                                   int h) {
-  int r, c;
-  for (r = 0; r < h; ++r)
-    for (c = 0; c < w; ++c)
-      dst[c * (dst_stride) + r] = src[r * (src_stride) + c];
-}
-
 void av1_convolve_2d_facade(const uint8_t *src, int src_stride, uint8_t *dst,
                             int dst_stride, int w, int h,
                             const InterpFilter *interp_filter,
@@ -500,14 +483,6 @@
 }
 
 #if CONFIG_HIGHBITDEPTH
-static INLINE void transpose_uint16(uint16_t *dst, int dst_stride,
-                                    const uint16_t *src, int src_stride, int w,
-                                    int h) {
-  int r, c;
-  for (r = 0; r < h; ++r)
-    for (c = 0; c < w; ++c) dst[c * dst_stride + r] = src[r * src_stride + c];
-}
-
 void av1_highbd_convolve_rounding_c(const int32_t *src, int src_stride,
                                     uint8_t *dst8, int dst_stride, int w, int h,
                                     int bits, int bd) {