Prune tx type search using previous stats

Record tx type usage and using the stats to prune following frame's
tx type search. Borg test result for speed 4:
         avg_psnr:   ovr_psnr:   ssim:
midres:   0.120        0.107    -0.034
Average speedup over whole midres set is 3.2%.
Will work on extending this feature to other speeds.

Change-Id: I864e1e4515c0f707aa48a3c0f9ed67b573603739
diff --git a/av1/encoder/encode_strategy.c b/av1/encoder/encode_strategy.c
index b6def7a..2e56077 100644
--- a/av1/encoder/encode_strategy.c
+++ b/av1/encoder/encode_strategy.c
@@ -182,23 +182,6 @@
   update_gf_group_index(cpi);
 }
 
-// Get update type of the current frame.
-static INLINE FRAME_UPDATE_TYPE get_frame_update_type(const AV1_COMP *cpi) {
-  const GF_GROUP *const gf_group = &cpi->gf_group;
-  if (gf_group->size == 0) {
-    // Special case 1: happens at the first frame of a video.
-    return KF_UPDATE;
-  }
-  if (gf_group->index == gf_group->size) {
-    // Special case 2: happens at the start of next GF group, or at the end of
-    // the key-frame group. So, not marked in gf_group->update_type array, but
-    // can be inferred implicitly.
-    return cpi->rc.source_alt_ref_active ? OVERLAY_UPDATE : GF_UPDATE;
-  }
-  // General case.
-  return gf_group->update_type[gf_group->index];
-}
-
 static void set_ext_overrides(AV1_COMP *const cpi,
                               EncodeFrameParams *const frame_params) {
   // Overrides the defaults with the externally supplied values with
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 75a09bf..7e0c395 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4613,6 +4613,7 @@
 
   av1_zero(*td->counts);
   av1_zero(rdc->comp_pred_diff);
+  av1_zero(rdc->tx_type_used);
 
   // Reset the flag.
   cpi->intrabc_used = 0;
@@ -4987,6 +4988,29 @@
   if (cm->delta_q_info.delta_q_present_flag && cpi->deltaq_used == 0) {
     cm->delta_q_info.delta_q_present_flag = 0;
   }
+
+  if (cpi->sf.tx_type_search.prune_tx_type_using_stats) {
+    const FRAME_UPDATE_TYPE update_type = get_frame_update_type(cpi);
+
+    for (i = 0; i < TX_SIZES_ALL; i++) {
+      int sum = 0;
+      int j;
+      int left = 1024;
+
+      for (j = 0; j < TX_TYPES; j++)
+        sum += cpi->td.rd_counts.tx_type_used[update_type][i][j];
+
+      for (j = TX_TYPES - 1; j >= 0; j--) {
+        int new_prob =
+            sum ? 1024 * cpi->td.rd_counts.tx_type_used[update_type][i][j] / sum
+                : (j ? 0 : 1024);
+        int prob = (cpi->tx_type_probs[update_type][i][j] + new_prob) >> 1;
+        left -= prob;
+        if (j == 0) prob += left;
+        cpi->tx_type_probs[update_type][i][j] = prob;
+      }
+    }
+  }
 }
 
 #define CHECK_PRECOMPUTED_REF_FRAME_MAP 0
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 385cc2b..be937ef 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -100,6 +100,142 @@
 #define FILE_NAME_LEN 100
 #endif
 
+const int default_tx_type_probs[FRAME_UPDATE_TYPES][TX_SIZES_ALL][TX_TYPES] = {
+  { { 221, 189, 214, 292, 0, 0, 0, 0, 0, 2, 38, 68, 0, 0, 0, 0 },
+    { 262, 203, 216, 239, 0, 0, 0, 0, 0, 1, 37, 66, 0, 0, 0, 0 },
+    { 315, 231, 239, 226, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 222, 188, 214, 287, 0, 0, 0, 0, 0, 2, 50, 61, 0, 0, 0, 0 },
+    { 256, 182, 205, 282, 0, 0, 0, 0, 0, 2, 21, 76, 0, 0, 0, 0 },
+    { 281, 214, 217, 222, 0, 0, 0, 0, 0, 1, 48, 41, 0, 0, 0, 0 },
+    { 263, 194, 225, 225, 0, 0, 0, 0, 0, 2, 15, 100, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 170, 192, 242, 293, 0, 0, 0, 0, 0, 1, 68, 58, 0, 0, 0, 0 },
+    { 199, 210, 213, 291, 0, 0, 0, 0, 0, 1, 14, 96, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+  { { 106, 69, 107, 278, 9, 15, 20, 45, 49, 23, 23, 88, 36, 74, 25, 57 },
+    { 105, 72, 81, 98, 45, 49, 47, 50, 56, 72, 30, 81, 33, 95, 27, 83 },
+    { 211, 105, 109, 120, 57, 62, 43, 49, 52, 58, 42, 116, 0, 0, 0, 0 },
+    { 1008, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 131, 57, 98, 172, 19, 40, 37, 64, 69, 22, 41, 52, 51, 77, 35, 59 },
+    { 176, 83, 93, 202, 22, 24, 28, 47, 50, 16, 12, 93, 26, 76, 17, 59 },
+    { 136, 72, 89, 95, 46, 59, 47, 56, 61, 68, 35, 51, 32, 82, 26, 69 },
+    { 122, 80, 87, 105, 49, 47, 46, 46, 57, 52, 13, 90, 19, 103, 15, 93 },
+    { 1009, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0 },
+    { 1011, 0, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 202, 20, 84, 114, 14, 60, 41, 79, 99, 21, 41, 15, 50, 84, 34, 66 },
+    { 196, 44, 23, 72, 30, 22, 28, 57, 67, 13, 4, 165, 15, 148, 9, 131 },
+    { 882, 0, 0, 0, 0, 0, 0, 0, 0, 142, 0, 0, 0, 0, 0, 0 },
+    { 840, 0, 0, 0, 0, 0, 0, 0, 0, 184, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+  { { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 } },
+  { { 213, 110, 141, 269, 12, 16, 15, 19, 21, 11, 38, 68, 22, 29, 16, 24 },
+    { 216, 119, 128, 143, 38, 41, 26, 30, 31, 30, 42, 70, 23, 36, 19, 32 },
+    { 367, 149, 154, 154, 38, 35, 17, 21, 21, 10, 22, 36, 0, 0, 0, 0 },
+    { 1022, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 219, 96, 127, 191, 21, 40, 25, 32, 34, 18, 45, 45, 33, 39, 26, 33 },
+    { 296, 99, 122, 198, 23, 21, 19, 24, 25, 13, 20, 64, 23, 32, 18, 27 },
+    { 275, 128, 142, 143, 35, 48, 23, 30, 29, 18, 42, 36, 18, 23, 14, 20 },
+    { 239, 132, 166, 175, 36, 27, 19, 21, 24, 14, 13, 85, 9, 31, 8, 25 },
+    { 1022, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0 },
+    { 1022, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 309, 25, 79, 59, 25, 80, 34, 53, 61, 25, 49, 23, 43, 64, 36, 59 },
+    { 270, 57, 40, 54, 50, 42, 41, 53, 56, 28, 17, 81, 45, 86, 34, 70 },
+    { 1005, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0 },
+    { 992, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+  { { 133, 63, 55, 83, 57, 87, 58, 72, 68, 16, 24, 35, 29, 105, 25, 114 },
+    { 131, 75, 74, 60, 71, 77, 65, 66, 73, 33, 21, 79, 20, 83, 18, 78 },
+    { 276, 95, 82, 58, 86, 93, 63, 60, 64, 17, 38, 92, 0, 0, 0, 0 },
+    { 1006, 0, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 147, 49, 75, 78, 50, 97, 60, 67, 76, 17, 42, 35, 31, 93, 27, 80 },
+    { 157, 49, 58, 75, 61, 52, 56, 67, 69, 12, 15, 79, 24, 119, 11, 120 },
+    { 178, 69, 83, 77, 69, 85, 72, 77, 77, 20, 35, 40, 25, 48, 23, 46 },
+    { 174, 55, 64, 57, 73, 68, 62, 61, 75, 15, 12, 90, 17, 99, 16, 86 },
+    { 1008, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0 },
+    { 1018, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 266, 31, 63, 64, 21, 52, 39, 54, 63, 30, 52, 31, 48, 89, 46, 75 },
+    { 272, 26, 32, 44, 29, 31, 32, 53, 51, 13, 13, 88, 22, 153, 16, 149 },
+    { 923, 0, 0, 0, 0, 0, 0, 0, 0, 101, 0, 0, 0, 0, 0, 0 },
+    { 969, 0, 0, 0, 0, 0, 0, 0, 0, 55, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+  { { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+    { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 } },
+  { { 158, 92, 125, 298, 12, 15, 20, 29, 31, 12, 29, 67, 34, 44, 23, 35 },
+    { 147, 94, 103, 123, 45, 48, 38, 41, 46, 48, 37, 78, 33, 63, 27, 53 },
+    { 268, 126, 125, 136, 54, 53, 31, 38, 38, 33, 35, 87, 0, 0, 0, 0 },
+    { 1018, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 159, 72, 103, 194, 20, 35, 37, 50, 56, 21, 39, 40, 51, 61, 38, 48 },
+    { 259, 86, 95, 188, 32, 20, 25, 34, 37, 13, 12, 85, 25, 53, 17, 43 },
+    { 189, 99, 113, 123, 45, 59, 37, 46, 48, 44, 39, 41, 31, 47, 26, 37 },
+    { 175, 110, 113, 128, 58, 38, 33, 33, 43, 29, 13, 100, 14, 68, 12, 57 },
+    { 1017, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0 },
+    { 1019, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 208, 22, 84, 101, 21, 59, 44, 70, 90, 25, 59, 13, 64, 67, 49, 48 },
+    { 277, 52, 32, 63, 43, 26, 33, 48, 54, 11, 6, 130, 18, 119, 11, 101 },
+    { 963, 0, 0, 0, 0, 0, 0, 0, 0, 61, 0, 0, 0, 0, 0, 0 },
+    { 979, 0, 0, 0, 0, 0, 0, 0, 0, 45, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+};
+
 static INLINE void Scale2Ratio(AOM_SCALING mode, int *hr, int *hs) {
   switch (mode) {
     case NORMAL:
@@ -4827,6 +4963,18 @@
   q_low = bottom_index;
   q_high = top_index;
 
+  if (cm->current_frame.frame_type == KEY_FRAME) {
+    av1_copy(cpi->tx_type_probs, default_tx_type_probs);
+
+    for (int f = 0; f < FRAME_UPDATE_TYPES; f++) {
+      // TODO(yunqing): Threshold can be updated adaptively for 1 frame.
+      if (f == KF_UPDATE || f == ARF_UPDATE)
+        cpi->tx_type_probs_thresh[f] = 10;
+      else
+        cpi->tx_type_probs_thresh[f] = 17;
+    }
+  }
+
   // Loop variables
   int loop_count = 0;
   int loop_at_this_size = 0;
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 53cc6a2..fe91842 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -603,6 +603,7 @@
   int global_motion_used[REF_FRAMES];
   int compound_ref_used_flag;
   int skip_mode_used_flag;
+  int tx_type_used[FRAME_UPDATE_TYPES][TX_SIZES_ALL][TX_TYPES];
 } RD_COUNTS;
 
 typedef struct ThreadData {
@@ -978,6 +979,9 @@
   int64_t vbp_threshold_copy;
   BLOCK_SIZE vbp_bsize_min;
 
+  int tx_type_probs[FRAME_UPDATE_TYPES][TX_SIZES_ALL][TX_TYPES];
+  int tx_type_probs_thresh[FRAME_UPDATE_TYPES];
+
   // Multi-threading
   int num_workers;
   AVxWorker *workers;
@@ -1444,6 +1448,23 @@
 #endif  // ENABLE_KF_TPL
 }
 
+// Get update type of the current frame.
+static INLINE FRAME_UPDATE_TYPE get_frame_update_type(const AV1_COMP *cpi) {
+  const GF_GROUP *const gf_group = &cpi->gf_group;
+  if (gf_group->size == 0) {
+    // Special case 1: happens at the first frame of a video.
+    return KF_UPDATE;
+  }
+  if (gf_group->index == gf_group->size) {
+    // Special case 2: happens at the start of next GF group, or at the end of
+    // the key-frame group. So, not marked in gf_group->update_type array, but
+    // can be inferred implicitly.
+    return cpi->rc.source_alt_ref_active ? OVERLAY_UPDATE : GF_UPDATE;
+  }
+  // General case.
+  return gf_group->update_type[gf_group->index];
+}
+
 #if CONFIG_COLLECT_PARTITION_STATS == 2
 static INLINE void av1_print_partition_stats(PartitionStats *part_stats) {
   FILE *f = fopen("partition_stats.csv", "w");
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 20956a8..deb585a 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -2073,6 +2073,10 @@
   const TX_CLASS tx_class = tx_type_to_class[tx_type];
   const SCAN_ORDER *const scan_order = get_scan(tx_size, tx_type);
   const int16_t *const scan = scan_order->scan;
+
+  // record tx type usage
+  td->rd_counts.tx_type_used[get_frame_update_type(cpi)][tx_size][tx_type]++;
+
 #if CONFIG_ENTROPY_STATS
   av1_update_eob_context(cdf_idx, eob, tx_size, tx_class, plane_type, ec_ctx,
                          td->counts, allow_update_cdf);
diff --git a/av1/encoder/ethread.c b/av1/encoder/ethread.c
index 18a2ef3..1856d69 100644
--- a/av1/encoder/ethread.c
+++ b/av1/encoder/ethread.c
@@ -27,6 +27,14 @@
   td->rd_counts.compound_ref_used_flag |=
       td_t->rd_counts.compound_ref_used_flag;
   td->rd_counts.skip_mode_used_flag |= td_t->rd_counts.skip_mode_used_flag;
+
+  for (int i = 0; i < FRAME_UPDATE_TYPES; i++) {
+    for (int j = 0; j < TX_SIZES_ALL; j++) {
+      for (int k = 0; k < TX_TYPES; k++)
+        td->rd_counts.tx_type_used[i][j][k] +=
+            td_t->rd_counts.tx_type_used[i][j][k];
+    }
+  }
 }
 
 static void update_delta_lf_for_row_mt(AV1_COMP *cpi) {
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 2b5af99..3476a37 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1626,7 +1626,8 @@
 
 static uint16_t prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
                             int blk_row, int blk_col, TxSetType tx_set_type,
-                            TX_TYPE_PRUNE_MODE prune_mode, int *txk_map) {
+                            TX_TYPE_PRUNE_MODE prune_mode, int *txk_map,
+                            uint16_t allowed_tx_mask) {
   int tx_type_table_2D[16] = {
     DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
     ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
@@ -1701,7 +1702,7 @@
   float max_score = 0.0f;
   for (int i = 0; i < 16; i++) {
     if (scores_2D[i] > max_score &&
-        av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
+        (allowed_tx_mask & (1 << tx_type_table_2D[i]))) {
       max_score = scores_2D[i];
       max_score_i = i;
     }
@@ -3080,11 +3081,41 @@
   } else {
     assert(plane == 0);
     allowed_tx_mask = ext_tx_used_flag;
+    int num_allowed = 0;
+    const FRAME_UPDATE_TYPE update_type = get_frame_update_type(cpi);
+    const int *tx_type_probs = cpi->tx_type_probs[update_type][tx_size];
+    int i;
+
+    if (cpi->sf.tx_type_search.prune_tx_type_using_stats) {
+      const int thresh = cpi->tx_type_probs_thresh[update_type];
+      uint16_t prune = 0;
+      int max_prob = -1;
+      int max_idx = 0;
+      for (i = 0; i < TX_TYPES; i++) {
+        if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
+          max_prob = tx_type_probs[i];
+          max_idx = i;
+        }
+      }
+
+      for (i = 0; i < TX_TYPES; i++) {
+        if (tx_type_probs[i] < thresh && i != max_idx) prune |= (1 << i);
+      }
+      allowed_tx_mask &= (~prune);
+    }
+
+    for (i = 0; i < TX_TYPES; i++) {
+      if (allowed_tx_mask & (1 << i)) num_allowed++;
+    }
+    assert(num_allowed > 0);
+
+    // Go through ML model only if num_allowed > 5.
     // !fast_tx_search && txk_end != txk_start && plane == 0
-    if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE && is_inter) {
-      const uint16_t prune =
-          prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
-                      cpi->sf.tx_type_search.prune_mode, txk_map);
+    if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE && is_inter &&
+        num_allowed > 5) {
+      const uint16_t prune = prune_tx_2D(
+          x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
+          cpi->sf.tx_type_search.prune_mode, txk_map, allowed_tx_mask);
       allowed_tx_mask &= (~prune);
     }
   }
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 20386d9..e895df9 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -387,6 +387,7 @@
     sf->perform_coeff_opt = is_boosted_arf2_bwd_type ? 2 : 4;
     sf->adaptive_txb_search_level = boosted ? 2 : 3;
     sf->mv.subpel_search_method = SUBPEL_TREE_PRUNED_MORE;
+    sf->tx_type_search.prune_tx_type_using_stats = 1;
   }
 
   if (speed >= 5) {
@@ -757,6 +758,7 @@
   sf->tx_type_search.fast_intra_tx_type_search = 0;
   sf->tx_type_search.fast_inter_tx_type_search = 0;
   sf->tx_type_search.skip_tx_search = 0;
+  sf->tx_type_search.prune_tx_type_using_stats = 0;
   sf->selective_ref_frame = 0;
   sf->less_rectangular_check_level = 0;
   sf->use_square_partition_only_threshold = BLOCK_128X128;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index b2a3e44..75e76ca 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -183,6 +183,9 @@
   // skip remaining transform type search when we found the rdcost of skip is
   // better than applying transform
   int skip_tx_search;
+
+  // Prune tx type search using previous frame stats.
+  int prune_tx_type_using_stats;
 } TX_TYPE_SEARCH;
 
 enum {
diff --git a/test/gf_max_pyr_height_test.cc b/test/gf_max_pyr_height_test.cc
index 28b363b..b722168 100644
--- a/test/gf_max_pyr_height_test.cc
+++ b/test/gf_max_pyr_height_test.cc
@@ -21,7 +21,7 @@
   int gf_max_pyr_height;
   double psnr_thresh;
 } kTestParams[] = {
-  { 0, 34.5 }, { 1, 34.75 }, { 2, 35.25 }, { 3, 35.50 }, { 4, 35.50 },
+  { 0, 34.5 }, { 1, 34.75 }, { 2, 35.22 }, { 3, 35.50 }, { 4, 35.50 },
 };
 
 // Compiler may decide to add some padding to the struct above for alignment,