Prune tx type search using previous stats
Record tx type usage and use the stats to prune the following frame's
tx type search. Borg test result for speed 4:
         avg_psnr:  ovr_psnr:  ssim:
midres:  0.120      0.107      -0.034
Average speedup over the whole midres set is 3.2%.
Will work on extending this feature to other speeds.
Change-Id: I864e1e4515c0f707aa48a3c0f9ed67b573603739
diff --git a/av1/encoder/encode_strategy.c b/av1/encoder/encode_strategy.c
index b6def7a..2e56077 100644
--- a/av1/encoder/encode_strategy.c
+++ b/av1/encoder/encode_strategy.c
@@ -182,23 +182,6 @@
update_gf_group_index(cpi);
}
-// Get update type of the current frame.
-static INLINE FRAME_UPDATE_TYPE get_frame_update_type(const AV1_COMP *cpi) {
- const GF_GROUP *const gf_group = &cpi->gf_group;
- if (gf_group->size == 0) {
- // Special case 1: happens at the first frame of a video.
- return KF_UPDATE;
- }
- if (gf_group->index == gf_group->size) {
- // Special case 2: happens at the start of next GF group, or at the end of
- // the key-frame group. So, not marked in gf_group->update_type array, but
- // can be inferred implicitly.
- return cpi->rc.source_alt_ref_active ? OVERLAY_UPDATE : GF_UPDATE;
- }
- // General case.
- return gf_group->update_type[gf_group->index];
-}
-
static void set_ext_overrides(AV1_COMP *const cpi,
EncodeFrameParams *const frame_params) {
// Overrides the defaults with the externally supplied values with
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 75a09bf..7e0c395 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4613,6 +4613,7 @@
av1_zero(*td->counts);
av1_zero(rdc->comp_pred_diff);
+ av1_zero(rdc->tx_type_used);
// Reset the flag.
cpi->intrabc_used = 0;
@@ -4987,6 +4988,29 @@
if (cm->delta_q_info.delta_q_present_flag && cpi->deltaq_used == 0) {
cm->delta_q_info.delta_q_present_flag = 0;
}
+
+ if (cpi->sf.tx_type_search.prune_tx_type_using_stats) {
+ const FRAME_UPDATE_TYPE update_type = get_frame_update_type(cpi);
+
+ for (i = 0; i < TX_SIZES_ALL; i++) {
+ int sum = 0;
+ int j;
+ int left = 1024;
+
+ for (j = 0; j < TX_TYPES; j++)
+ sum += cpi->td.rd_counts.tx_type_used[update_type][i][j];
+
+ for (j = TX_TYPES - 1; j >= 0; j--) {
+ int new_prob =
+ sum ? 1024 * cpi->td.rd_counts.tx_type_used[update_type][i][j] / sum
+ : (j ? 0 : 1024);
+ int prob = (cpi->tx_type_probs[update_type][i][j] + new_prob) >> 1;
+ left -= prob;
+ if (j == 0) prob += left;
+ cpi->tx_type_probs[update_type][i][j] = prob;
+ }
+ }
+ }
}
#define CHECK_PRECOMPUTED_REF_FRAME_MAP 0
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 385cc2b..be937ef 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -100,6 +100,142 @@
#define FILE_NAME_LEN 100
#endif
+const int default_tx_type_probs[FRAME_UPDATE_TYPES][TX_SIZES_ALL][TX_TYPES] = {
+ { { 221, 189, 214, 292, 0, 0, 0, 0, 0, 2, 38, 68, 0, 0, 0, 0 },
+ { 262, 203, 216, 239, 0, 0, 0, 0, 0, 1, 37, 66, 0, 0, 0, 0 },
+ { 315, 231, 239, 226, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 222, 188, 214, 287, 0, 0, 0, 0, 0, 2, 50, 61, 0, 0, 0, 0 },
+ { 256, 182, 205, 282, 0, 0, 0, 0, 0, 2, 21, 76, 0, 0, 0, 0 },
+ { 281, 214, 217, 222, 0, 0, 0, 0, 0, 1, 48, 41, 0, 0, 0, 0 },
+ { 263, 194, 225, 225, 0, 0, 0, 0, 0, 2, 15, 100, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 170, 192, 242, 293, 0, 0, 0, 0, 0, 1, 68, 58, 0, 0, 0, 0 },
+ { 199, 210, 213, 291, 0, 0, 0, 0, 0, 1, 14, 96, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { { 106, 69, 107, 278, 9, 15, 20, 45, 49, 23, 23, 88, 36, 74, 25, 57 },
+ { 105, 72, 81, 98, 45, 49, 47, 50, 56, 72, 30, 81, 33, 95, 27, 83 },
+ { 211, 105, 109, 120, 57, 62, 43, 49, 52, 58, 42, 116, 0, 0, 0, 0 },
+ { 1008, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 131, 57, 98, 172, 19, 40, 37, 64, 69, 22, 41, 52, 51, 77, 35, 59 },
+ { 176, 83, 93, 202, 22, 24, 28, 47, 50, 16, 12, 93, 26, 76, 17, 59 },
+ { 136, 72, 89, 95, 46, 59, 47, 56, 61, 68, 35, 51, 32, 82, 26, 69 },
+ { 122, 80, 87, 105, 49, 47, 46, 46, 57, 52, 13, 90, 19, 103, 15, 93 },
+ { 1009, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0 },
+ { 1011, 0, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 202, 20, 84, 114, 14, 60, 41, 79, 99, 21, 41, 15, 50, 84, 34, 66 },
+ { 196, 44, 23, 72, 30, 22, 28, 57, 67, 13, 4, 165, 15, 148, 9, 131 },
+ { 882, 0, 0, 0, 0, 0, 0, 0, 0, 142, 0, 0, 0, 0, 0, 0 },
+ { 840, 0, 0, 0, 0, 0, 0, 0, 0, 184, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 } },
+ { { 213, 110, 141, 269, 12, 16, 15, 19, 21, 11, 38, 68, 22, 29, 16, 24 },
+ { 216, 119, 128, 143, 38, 41, 26, 30, 31, 30, 42, 70, 23, 36, 19, 32 },
+ { 367, 149, 154, 154, 38, 35, 17, 21, 21, 10, 22, 36, 0, 0, 0, 0 },
+ { 1022, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 219, 96, 127, 191, 21, 40, 25, 32, 34, 18, 45, 45, 33, 39, 26, 33 },
+ { 296, 99, 122, 198, 23, 21, 19, 24, 25, 13, 20, 64, 23, 32, 18, 27 },
+ { 275, 128, 142, 143, 35, 48, 23, 30, 29, 18, 42, 36, 18, 23, 14, 20 },
+ { 239, 132, 166, 175, 36, 27, 19, 21, 24, 14, 13, 85, 9, 31, 8, 25 },
+ { 1022, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0 },
+ { 1022, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 309, 25, 79, 59, 25, 80, 34, 53, 61, 25, 49, 23, 43, 64, 36, 59 },
+ { 270, 57, 40, 54, 50, 42, 41, 53, 56, 28, 17, 81, 45, 86, 34, 70 },
+ { 1005, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0 },
+ { 992, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { { 133, 63, 55, 83, 57, 87, 58, 72, 68, 16, 24, 35, 29, 105, 25, 114 },
+ { 131, 75, 74, 60, 71, 77, 65, 66, 73, 33, 21, 79, 20, 83, 18, 78 },
+ { 276, 95, 82, 58, 86, 93, 63, 60, 64, 17, 38, 92, 0, 0, 0, 0 },
+ { 1006, 0, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 147, 49, 75, 78, 50, 97, 60, 67, 76, 17, 42, 35, 31, 93, 27, 80 },
+ { 157, 49, 58, 75, 61, 52, 56, 67, 69, 12, 15, 79, 24, 119, 11, 120 },
+ { 178, 69, 83, 77, 69, 85, 72, 77, 77, 20, 35, 40, 25, 48, 23, 46 },
+ { 174, 55, 64, 57, 73, 68, 62, 61, 75, 15, 12, 90, 17, 99, 16, 86 },
+ { 1008, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0 },
+ { 1018, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 266, 31, 63, 64, 21, 52, 39, 54, 63, 30, 52, 31, 48, 89, 46, 75 },
+ { 272, 26, 32, 44, 29, 31, 32, 53, 51, 13, 13, 88, 22, 153, 16, 149 },
+ { 923, 0, 0, 0, 0, 0, 0, 0, 0, 101, 0, 0, 0, 0, 0, 0 },
+ { 969, 0, 0, 0, 0, 0, 0, 0, 0, 55, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 },
+ { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 } },
+ { { 158, 92, 125, 298, 12, 15, 20, 29, 31, 12, 29, 67, 34, 44, 23, 35 },
+ { 147, 94, 103, 123, 45, 48, 38, 41, 46, 48, 37, 78, 33, 63, 27, 53 },
+ { 268, 126, 125, 136, 54, 53, 31, 38, 38, 33, 35, 87, 0, 0, 0, 0 },
+ { 1018, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 159, 72, 103, 194, 20, 35, 37, 50, 56, 21, 39, 40, 51, 61, 38, 48 },
+ { 259, 86, 95, 188, 32, 20, 25, 34, 37, 13, 12, 85, 25, 53, 17, 43 },
+ { 189, 99, 113, 123, 45, 59, 37, 46, 48, 44, 39, 41, 31, 47, 26, 37 },
+ { 175, 110, 113, 128, 58, 38, 33, 33, 43, 29, 13, 100, 14, 68, 12, 57 },
+ { 1017, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0 },
+ { 1019, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 208, 22, 84, 101, 21, 59, 44, 70, 90, 25, 59, 13, 64, 67, 49, 48 },
+ { 277, 52, 32, 63, 43, 26, 33, 48, 54, 11, 6, 130, 18, 119, 11, 101 },
+ { 963, 0, 0, 0, 0, 0, 0, 0, 0, 61, 0, 0, 0, 0, 0, 0 },
+ { 979, 0, 0, 0, 0, 0, 0, 0, 0, 45, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+};
+
static INLINE void Scale2Ratio(AOM_SCALING mode, int *hr, int *hs) {
switch (mode) {
case NORMAL:
@@ -4827,6 +4963,18 @@
q_low = bottom_index;
q_high = top_index;
+ if (cm->current_frame.frame_type == KEY_FRAME) {
+ av1_copy(cpi->tx_type_probs, default_tx_type_probs);
+
+ for (int f = 0; f < FRAME_UPDATE_TYPES; f++) {
+ // TODO(yunqing): Threshold can be updated adaptively for 1 frame.
+ if (f == KF_UPDATE || f == ARF_UPDATE)
+ cpi->tx_type_probs_thresh[f] = 10;
+ else
+ cpi->tx_type_probs_thresh[f] = 17;
+ }
+ }
+
// Loop variables
int loop_count = 0;
int loop_at_this_size = 0;
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 53cc6a2..fe91842 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -603,6 +603,7 @@
int global_motion_used[REF_FRAMES];
int compound_ref_used_flag;
int skip_mode_used_flag;
+ int tx_type_used[FRAME_UPDATE_TYPES][TX_SIZES_ALL][TX_TYPES];
} RD_COUNTS;
typedef struct ThreadData {
@@ -978,6 +979,9 @@
int64_t vbp_threshold_copy;
BLOCK_SIZE vbp_bsize_min;
+ int tx_type_probs[FRAME_UPDATE_TYPES][TX_SIZES_ALL][TX_TYPES];
+ int tx_type_probs_thresh[FRAME_UPDATE_TYPES];
+
// Multi-threading
int num_workers;
AVxWorker *workers;
@@ -1444,6 +1448,23 @@
#endif // ENABLE_KF_TPL
}
+// Returns the update type of the current frame, derived from cpi->gf_group.
+static INLINE FRAME_UPDATE_TYPE get_frame_update_type(const AV1_COMP *cpi) {
+ const GF_GROUP *const gf_group = &cpi->gf_group;
+ if (gf_group->size == 0) {
+ // Special case 1: empty GF group; happens at the first frame of a video.
+ return KF_UPDATE;
+ }
+ if (gf_group->index == gf_group->size) {
+ // Special case 2: happens at the start of next GF group, or at the end of
+ // the key-frame group. The type is not marked in gf_group->update_type, so
+ // it is inferred from rc.source_alt_ref_active (OVERLAY_UPDATE vs GF_UPDATE).
+ return cpi->rc.source_alt_ref_active ? OVERLAY_UPDATE : GF_UPDATE;
+ }
+ // General case: use the type recorded at the current GF group index.
+ return gf_group->update_type[gf_group->index];
+}
+
#if CONFIG_COLLECT_PARTITION_STATS == 2
static INLINE void av1_print_partition_stats(PartitionStats *part_stats) {
FILE *f = fopen("partition_stats.csv", "w");
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 20956a8..deb585a 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -2073,6 +2073,10 @@
const TX_CLASS tx_class = tx_type_to_class[tx_type];
const SCAN_ORDER *const scan_order = get_scan(tx_size, tx_type);
const int16_t *const scan = scan_order->scan;
+
+ // record tx type usage
+ td->rd_counts.tx_type_used[get_frame_update_type(cpi)][tx_size][tx_type]++;
+
#if CONFIG_ENTROPY_STATS
av1_update_eob_context(cdf_idx, eob, tx_size, tx_class, plane_type, ec_ctx,
td->counts, allow_update_cdf);
diff --git a/av1/encoder/ethread.c b/av1/encoder/ethread.c
index 18a2ef3..1856d69 100644
--- a/av1/encoder/ethread.c
+++ b/av1/encoder/ethread.c
@@ -27,6 +27,14 @@
td->rd_counts.compound_ref_used_flag |=
td_t->rd_counts.compound_ref_used_flag;
td->rd_counts.skip_mode_used_flag |= td_t->rd_counts.skip_mode_used_flag;
+
+ for (int i = 0; i < FRAME_UPDATE_TYPES; i++) {
+ for (int j = 0; j < TX_SIZES_ALL; j++) {
+ for (int k = 0; k < TX_TYPES; k++)
+ td->rd_counts.tx_type_used[i][j][k] +=
+ td_t->rd_counts.tx_type_used[i][j][k];
+ }
+ }
}
static void update_delta_lf_for_row_mt(AV1_COMP *cpi) {
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 2b5af99..3476a37 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1626,7 +1626,8 @@
static uint16_t prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
int blk_row, int blk_col, TxSetType tx_set_type,
- TX_TYPE_PRUNE_MODE prune_mode, int *txk_map) {
+ TX_TYPE_PRUNE_MODE prune_mode, int *txk_map,
+ uint16_t allowed_tx_mask) {
int tx_type_table_2D[16] = {
DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT,
ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST,
@@ -1701,7 +1702,7 @@
float max_score = 0.0f;
for (int i = 0; i < 16; i++) {
if (scores_2D[i] > max_score &&
- av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
+ (allowed_tx_mask & (1 << tx_type_table_2D[i]))) {
max_score = scores_2D[i];
max_score_i = i;
}
@@ -3080,11 +3081,41 @@
} else {
assert(plane == 0);
allowed_tx_mask = ext_tx_used_flag;
+ int num_allowed = 0;
+ const FRAME_UPDATE_TYPE update_type = get_frame_update_type(cpi);
+ const int *tx_type_probs = cpi->tx_type_probs[update_type][tx_size];
+ int i;
+
+ if (cpi->sf.tx_type_search.prune_tx_type_using_stats) {
+ const int thresh = cpi->tx_type_probs_thresh[update_type];
+ uint16_t prune = 0;
+ int max_prob = -1;
+ int max_idx = 0;
+ for (i = 0; i < TX_TYPES; i++) {
+ if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
+ max_prob = tx_type_probs[i];
+ max_idx = i;
+ }
+ }
+
+ for (i = 0; i < TX_TYPES; i++) {
+ if (tx_type_probs[i] < thresh && i != max_idx) prune |= (1 << i);
+ }
+ allowed_tx_mask &= (~prune);
+ }
+
+ for (i = 0; i < TX_TYPES; i++) {
+ if (allowed_tx_mask & (1 << i)) num_allowed++;
+ }
+ assert(num_allowed > 0);
+
+ // Go through ML model only if num_allowed > 5.
// !fast_tx_search && txk_end != txk_start && plane == 0
- if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE && is_inter) {
- const uint16_t prune =
- prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
- cpi->sf.tx_type_search.prune_mode, txk_map);
+ if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE && is_inter &&
+ num_allowed > 5) {
+ const uint16_t prune = prune_tx_2D(
+ x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
+ cpi->sf.tx_type_search.prune_mode, txk_map, allowed_tx_mask);
allowed_tx_mask &= (~prune);
}
}
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 20386d9..e895df9 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -387,6 +387,7 @@
sf->perform_coeff_opt = is_boosted_arf2_bwd_type ? 2 : 4;
sf->adaptive_txb_search_level = boosted ? 2 : 3;
sf->mv.subpel_search_method = SUBPEL_TREE_PRUNED_MORE;
+ sf->tx_type_search.prune_tx_type_using_stats = 1;
}
if (speed >= 5) {
@@ -757,6 +758,7 @@
sf->tx_type_search.fast_intra_tx_type_search = 0;
sf->tx_type_search.fast_inter_tx_type_search = 0;
sf->tx_type_search.skip_tx_search = 0;
+ sf->tx_type_search.prune_tx_type_using_stats = 0;
sf->selective_ref_frame = 0;
sf->less_rectangular_check_level = 0;
sf->use_square_partition_only_threshold = BLOCK_128X128;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index b2a3e44..75e76ca 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -183,6 +183,9 @@
// skip remaining transform type search when we found the rdcost of skip is
// better than applying transform
int skip_tx_search;
+
+ // Prune tx type search using previous frame stats.
+ int prune_tx_type_using_stats;
} TX_TYPE_SEARCH;
enum {
diff --git a/test/gf_max_pyr_height_test.cc b/test/gf_max_pyr_height_test.cc
index 28b363b..b722168 100644
--- a/test/gf_max_pyr_height_test.cc
+++ b/test/gf_max_pyr_height_test.cc
@@ -21,7 +21,7 @@
int gf_max_pyr_height;
double psnr_thresh;
} kTestParams[] = {
- { 0, 34.5 }, { 1, 34.75 }, { 2, 35.25 }, { 3, 35.50 }, { 4, 35.50 },
+ { 0, 34.5 }, { 1, 34.75 }, { 2, 35.22 }, { 3, 35.50 }, { 4, 35.50 },
};
// Compiler may decide to add some padding to the struct above for alignment,