Edit ext-tx so it isn't doing redundant prunes

The original pruning function was not taking into account
that certain tx sizes/block sizes use a reduced tx set.

Prune 1: -0.3% performance drop, 20% speedup on foreman video
Prune 2: -0.48% performance drop, 30% speedup on foreman video

Change-Id: I557e919d97a89f787b47b3c8579a080db57f91d0
diff --git a/vp10/encoder/rdopt.c b/vp10/encoder/rdopt.c
index d040e0b..f9b3f8d 100644
--- a/vp10/encoder/rdopt.c
+++ b/vp10/encoder/rdopt.c
@@ -360,8 +360,50 @@
 // constants for prune 1 and prune 2 decision boundaries
 #define FAST_EXT_TX_CORR_MID 0.0
 #define FAST_EXT_TX_EDST_MID 0.1
-#define FAST_EXT_TX_CORR_MARGIN 0.3
-#define FAST_EXT_TX_EDST_MARGIN 0.5
+#define FAST_EXT_TX_CORR_MARGIN 0.5
+#define FAST_EXT_TX_EDST_MARGIN 0.3
+
+static const TX_TYPE_1D vtx_tab[TX_TYPES] = {
+  DCT_1D,
+  ADST_1D,
+  DCT_1D,
+  ADST_1D,
+#if CONFIG_EXT_TX
+  FLIPADST_1D,
+  DCT_1D,
+  FLIPADST_1D,
+  ADST_1D,
+  FLIPADST_1D,
+  IDTX_1D,
+  DCT_1D,
+  IDTX_1D,
+  ADST_1D,
+  IDTX_1D,
+  FLIPADST_1D,
+  IDTX_1D,
+#endif  // CONFIG_EXT_TX
+};
+
+static const TX_TYPE_1D htx_tab[TX_TYPES] = {
+  DCT_1D,
+  DCT_1D,
+  ADST_1D,
+  ADST_1D,
+#if CONFIG_EXT_TX
+  DCT_1D,
+  FLIPADST_1D,
+  FLIPADST_1D,
+  FLIPADST_1D,
+  ADST_1D,
+  IDTX_1D,
+  IDTX_1D,
+  DCT_1D,
+  IDTX_1D,
+  ADST_1D,
+  IDTX_1D,
+  FLIPADST_1D,
+#endif  // CONFIG_EXT_TX
+};
 
 static void get_energy_distribution_fine(const VP10_COMP *cpi,
                                          BLOCK_SIZE bsize,
@@ -586,7 +628,8 @@
 static int prune_two_for_sby(const VP10_COMP *cpi,
                              BLOCK_SIZE bsize,
                              MACROBLOCK *x,
-                             MACROBLOCKD *xd) {
+                             MACROBLOCKD *xd, int adst_flipadst,
+                             int dct_idtx) {
   struct macroblock_plane *const p = &x->plane[0];
   struct macroblockd_plane *const pd = &xd->plane[0];
   const BLOCK_SIZE bs = get_plane_block_size(bsize, pd);
@@ -594,12 +637,17 @@
   const int bh = 4 << (b_height_log2_lookup[bs]);
   double hdist[3] = {0, 0, 0}, vdist[3] = {0, 0, 0};
   double hcorr, vcorr;
+  int prune = 0;
   vp10_subtract_plane(x, bsize, 0);
-  return adst_vs_flipadst(cpi, bsize, p->src.buf, p->src.stride, pd->dst.buf,
-                          pd->dst.stride, hdist, vdist) |
-         dct_vs_idtx(p->src_diff, bw, bw, bh, &hcorr, &vcorr);
-}
 
+  if (adst_flipadst)
+    prune |= adst_vs_flipadst(cpi, bsize, p->src.buf, p->src.stride,
+                              pd->dst.buf, pd->dst.stride, hdist, vdist);
+  if (dct_idtx)
+    prune |= dct_vs_idtx(p->src_diff, bw, bw, bh, &hcorr, &vcorr);
+
+  return prune;
+}
 #endif  // CONFIG_EXT_TX
 
 // Performance drop: 0.3%, Speed improvement: 5%
@@ -618,17 +666,32 @@
 static int prune_tx_types(const VP10_COMP *cpi,
                           BLOCK_SIZE bsize,
                           MACROBLOCK *x,
-                          MACROBLOCKD *xd) {
+                          MACROBLOCKD *xd, int tx_set) {
+#if CONFIG_EXT_TX
+  const int *tx_set_1D = ext_tx_used_inter_1D[tx_set];
+#else
+  const int tx_set_1D[TX_TYPES_1D] = {0};
+#endif
+
   switch (cpi->sf.tx_type_search) {
     case NO_PRUNE:
       return 0;
       break;
     case PRUNE_ONE :
+      if ((tx_set >= 0) & !(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D]))
+        return 0;
       return prune_one_for_sby(cpi, bsize, x, xd);
       break;
   #if CONFIG_EXT_TX
     case PRUNE_TWO :
-      return prune_two_for_sby(cpi, bsize, x, xd);
+      if ((tx_set >= 0) & !(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) {
+        if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D]))
+          return 0;
+        return prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
+      }
+      if ((tx_set >= 0) & !(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D]))
+        return prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
+      return prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
       break;
   #endif
   }
@@ -640,46 +703,12 @@
                              int prune) {
 // TODO(sarahparker) implement for non ext tx
 #if CONFIG_EXT_TX
-  static TX_TYPE_1D vtx_tab[TX_TYPES] = {
-    DCT_1D,
-    ADST_1D,
-    DCT_1D,
-    ADST_1D,
-    FLIPADST_1D,
-    DCT_1D,
-    FLIPADST_1D,
-    ADST_1D,
-    FLIPADST_1D,
-    IDTX_1D,
-    DCT_1D,
-    IDTX_1D,
-    ADST_1D,
-    IDTX_1D,
-    FLIPADST_1D,
-    IDTX_1D,
-  };
-  static TX_TYPE_1D htx_tab[TX_TYPES] = {
-    DCT_1D,
-    DCT_1D,
-    ADST_1D,
-    ADST_1D,
-    DCT_1D,
-    FLIPADST_1D,
-    FLIPADST_1D,
-    FLIPADST_1D,
-    ADST_1D,
-    IDTX_1D,
-    IDTX_1D,
-    DCT_1D,
-    IDTX_1D,
-    ADST_1D,
-    IDTX_1D,
-    FLIPADST_1D,
-  };
   return !(((prune >> vtx_tab[tx_type]) & 1) |
          ((prune >> (htx_tab[tx_type] + 8)) & 1));
 #else
   // temporary to avoid compiler warnings
+  (void) vtx_tab;
+  (void) htx_tab;
   (void) tx_type;
   (void) prune;
   return 1;
@@ -1517,13 +1546,19 @@
   int ext_tx_set;
 #endif  // CONFIG_EXT_TX
 
-  if (is_inter && cpi->sf.tx_type_search > 0)
-    prune = prune_tx_types(cpi, bs, x, xd);
   mbmi->tx_size = VPXMIN(max_tx_size, largest_tx_size);
 
 #if CONFIG_EXT_TX
   ext_tx_set = get_ext_tx_set(mbmi->tx_size, bs, is_inter);
+#endif  // CONFIG_EXT_TX
 
+  if (is_inter && cpi->sf.tx_type_search > 0)
+#if CONFIG_EXT_TX
+    prune = prune_tx_types(cpi, bs, x, xd, ext_tx_set);
+#else
+    prune = prune_tx_types(cpi, bs, x, xd, 0);
+#endif
+#if CONFIG_EXT_TX
   if (get_ext_tx_types(mbmi->tx_size, bs, is_inter) > 1 &&
       !xd->lossless[mbmi->segment_id]) {
     for (tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
@@ -1661,7 +1696,9 @@
   int prune = 0;
 
   if (is_inter && cpi->sf.tx_type_search > 0)
-    prune = prune_tx_types(cpi, bs, x, xd);
+    // passing -1 in for tx_set indicates that all 1D
+    // transforms should be considered for pruning
+    prune = prune_tx_types(cpi, bs, x, xd, -1);
 
   *distortion = INT64_MAX;
   *rate       = INT_MAX;
@@ -3281,7 +3318,11 @@
 #endif  // CONFIG_EXT_TX
 
   if (is_inter && cpi->sf.tx_type_search > 0)
-    prune = prune_tx_types(cpi, bsize, x, xd);
+#if CONFIG_EXT_TX
+    prune = prune_tx_types(cpi, bsize, x, xd, ext_tx_set);
+#else
+    prune = prune_tx_types(cpi, bsize, x, xd, 0);
+#endif
 
   *distortion = INT64_MAX;
   *rate       = INT_MAX;