Enable tx type pruning based on probability score for winner mode

This patch will introduce a prune tx mode to prune tx type based on cumulative
probability given by the ML model and number of allowed transforms.
This speed feature is applicable for inter frames and for speed 4.

            Instruction Count
cpu-used       Reduction        Quality Loss
    4            1.198%           -0.068%

STATS_CHANGED

Change-Id: Idf837a07d7fa1c4e122a171033757f47ce7ec49b
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index c99da91..76127fd 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -413,6 +413,7 @@
   int must_find_valid_partition;
   int recalc_luma_mc_data;  // Flag to indicate recalculation of MC data during
                             // interpolation filter search
+  int prune_mode;
   uint32_t tx_domain_dist_threshold;
   int use_transform_domain_distortion;
   // The likelihood of an edge existing in the block (using partial Canny edge
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 85dc74a..7c8fd21 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1878,10 +1878,10 @@
   }
 }
 
-static uint16_t prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
-                            int blk_row, int blk_col, TxSetType tx_set_type,
-                            TX_TYPE_PRUNE_MODE prune_mode, int *txk_map,
-                            uint16_t allowed_tx_mask) {
+static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
+                        int blk_row, int blk_col, TxSetType tx_set_type,
+                        TX_TYPE_PRUNE_MODE prune_mode, int *txk_map,
+                        uint16_t *allowed_tx_mask) {
   int tx_type_table_2D[16] = {
     DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
     ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
@@ -1890,7 +1890,7 @@
   };
   if (tx_set_type != EXT_TX_SET_ALL16 &&
       tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
-    return 0;
+    return;
 #if CONFIG_NN_V2
   NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
   NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
@@ -1898,7 +1898,7 @@
   const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
   const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
 #endif
-  if (!nn_config_hor || !nn_config_ver) return 0;  // Model not established yet.
+  if (!nn_config_hor || !nn_config_ver) return;  // Model not established yet.
 
   aom_clear_system_state();
   float hfeatures[16], vfeatures[16];
@@ -1940,7 +1940,7 @@
 
   av1_nn_softmax(scores_2D_raw, scores_2D, 16);
 
-  const int prune_aggr_table[3][2] = { { 4, 1 }, { 6, 3 }, { 9, 6 } };
+  const int prune_aggr_table[4][2] = { { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 } };
   int pruning_aggressiveness = 0;
   if (tx_set_type == EXT_TX_SET_ALL16) {
     pruning_aggressiveness =
@@ -1956,7 +1956,7 @@
   float max_score = 0.0f;
   for (int i = 0; i < 16; i++) {
     if (scores_2D[i] > max_score &&
-        (allowed_tx_mask & (1 << tx_type_table_2D[i]))) {
+        (*allowed_tx_mask & (1 << tx_type_table_2D[i]))) {
       max_score = scores_2D[i];
       max_score_i = i;
     }
@@ -1965,16 +1965,53 @@
   const float score_thresh =
       prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
 
-  uint16_t prune_bitmask = 0;
-  for (int i = 0; i < 16; i++) {
-    if (scores_2D[i] < score_thresh && i != max_score_i)
-      prune_bitmask |= (1 << tx_type_table_2D[i]);
+  uint16_t allow_bitmask = 0;
+  float sum_score = 0.0;
+  // Calculate sum of allowed tx type score and Populate allow bit mask based
+  // on score_thresh and allowed_tx_mask
+  for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
+    int allow_tx_type = *allowed_tx_mask & (1 << tx_type_table_2D[tx_idx]);
+    if ((scores_2D[tx_idx] >= score_thresh && allow_tx_type) ||
+        tx_idx == max_score_i) {
+      // Set allow mask based on score_thresh and tx type with max score
+      allow_bitmask |= (1 << tx_type_table_2D[tx_idx]);
+
+      // Accumulate score of allowed tx type
+      sum_score += scores_2D[tx_idx];
+    }
   }
-
+  // Sort tx type probability of all types
   sort_probability(scores_2D, tx_type_table_2D, TX_TYPES);
-  memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
 
-  return prune_bitmask;
+  // Enable more pruning based on tx type probability and number of allowed tx
+  // types
+  if (prune_mode == PRUNE_2D_AGGRESSIVE) {
+    float temp_score = 0.0;
+    float score_ratio = 0.0;
+    int tx_idx, tx_count = 0;
+    const float inv_sum_score = 100 / sum_score;
+    // Get allowed tx types based on sorted probability score and tx count
+    for (tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
+      // Skip the tx type which has more than 30% of cumulative
+      // probability and allowed tx type count is more than 2
+      if (score_ratio > 30.0 && tx_count >= 2) break;
+
+      // Calculate cumulative probability of allowed tx types
+      if (allow_bitmask & (1 << tx_type_table_2D[tx_idx])) {
+        // Calculate cumulative probability
+        temp_score += scores_2D[tx_idx];
+
+        // Calculate percentage of cumulative probability of allowed tx type
+        score_ratio = temp_score * inv_sum_score;
+        tx_count++;
+      }
+    }
+    // Set remaining tx types as pruned
+    for (; tx_idx < TX_TYPES; tx_idx++)
+      allow_bitmask &= ~(1 << tx_type_table_2D[tx_idx]);
+  }
+  memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
+  *allowed_tx_mask = allow_bitmask;
 }
 
 static AOM_INLINE void model_rd_from_sse(const AV1_COMP *const cpi,
@@ -3363,14 +3400,12 @@
     }
     assert(num_allowed > 0);
 
-    // Go through ML model only if num_allowed > 5.
+    int allowed_tx_count = (x->prune_mode == PRUNE_2D_AGGRESSIVE) ? 1 : 5;
     // !fast_tx_search && txk_end != txk_start && plane == 0
-    if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE && is_inter &&
-        num_allowed > 5) {
-      const uint16_t prune = prune_tx_2D(
-          x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
-          cpi->sf.tx_type_search.prune_mode, txk_map, allowed_tx_mask);
-      allowed_tx_mask &= (~prune);
+    if (x->prune_mode >= PRUNE_2D_ACCURATE && is_inter &&
+        num_allowed > allowed_tx_count) {
+      prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
+                  x->prune_mode, txk_map, &allowed_tx_mask);
     }
   }
 
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index dfb36f0..3bef55d 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -244,6 +244,19 @@
   x->tx_mode = select_tx_mode(cpi, x->tx_size_search_method);
 }
 
+static INLINE void set_tx_type_prune(const SPEED_FEATURES *sf, MACROBLOCK *x,
+                                     int enable_winner_mode_tx_type_pruning,
+                                     int is_winner_mode) {
+  // Populate prune transform mode appropriately
+  x->prune_mode = sf->tx_type_search.prune_mode;
+  if (enable_winner_mode_tx_type_pruning) {
+    if (is_winner_mode)
+      x->prune_mode = NO_PRUNE;
+    else
+      x->prune_mode = PRUNE_2D_AGGRESSIVE;
+  }
+}
+
 static INLINE void set_tx_domain_dist_params(
     const struct AV1_COMP *cpi, MACROBLOCK *x,
     int enable_winner_mode_for_tx_domain_dist, int is_winner_mode) {
@@ -313,6 +326,8 @@
           get_rd_opt_coeff_thresh(cpi->coeff_opt_dist_threshold, 0, 0);
       // Set default transform size search method
       set_tx_size_search_method(cpi, x, 0, 0);
+      // Set default transform type prune
+      set_tx_type_prune(sf, x, 0, 0);
       break;
     case MODE_EVAL:
       x->use_default_intra_tx_type =
@@ -334,6 +349,9 @@
       // Set the transform size search method for mode evaluation
       set_tx_size_search_method(cpi, x, sf->enable_winner_mode_for_tx_size_srch,
                                 0);
+      // Set transform type prune for mode evaluation
+      set_tx_type_prune(
+          sf, x, sf->tx_type_search.enable_winner_mode_tx_type_pruning, 0);
       break;
     case WINNER_MODE_EVAL:
       x->use_default_inter_tx_type = 0;
@@ -352,6 +370,9 @@
       // Set the transform size search method for winner mode evaluation
       set_tx_size_search_method(cpi, x, sf->enable_winner_mode_for_tx_size_srch,
                                 1);
+      // Set default transform type prune mode for winner mode evaluation
+      set_tx_type_prune(
+          sf, x, sf->tx_type_search.enable_winner_mode_tx_type_pruning, 1);
       break;
     default: assert(0);
   }
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 2939876..1a8ea37 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -429,6 +429,7 @@
     // TODO(any): Extend multi-winner mode processing support for inter frames
     sf->enable_multiwinner_mode_process =
         frame_is_intra_only(&cpi->common) ? 1 : 0;
+    sf->tx_type_search.enable_winner_mode_tx_type_pruning = 1;
     // TODO(any): Experiment with this speed feature set to 2 for higher quality
     // presets as well
     sf->skip_intra_in_interframe = 2;
@@ -812,6 +813,7 @@
   sf->tx_type_search.fast_inter_tx_type_search = 0;
   sf->tx_type_search.skip_tx_search = 0;
   sf->tx_type_search.prune_tx_type_using_stats = 0;
+  sf->tx_type_search.enable_winner_mode_tx_type_pruning = 0;
   sf->selective_ref_frame = 0;
   sf->less_rectangular_check_level = 0;
   sf->use_square_partition_only_threshold = BLOCK_128X128;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 71222d0..2aaf859 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -163,6 +163,8 @@
   // similar, but applies much more aggressive pruning to get better speed-up
   PRUNE_2D_FAST = 2,
   PRUNE_2D_MORE = 3,
+  // More aggressive pruning based on tx type score and allowed tx count
+  PRUNE_2D_AGGRESSIVE = 4,
 } UENUM1BYTE(TX_TYPE_PRUNE_MODE);
 
 typedef struct {
@@ -186,6 +188,12 @@
 
   // Prune tx type search using previous frame stats.
   int prune_tx_type_using_stats;
+
+  // Flag used to control the winner mode processing for tx type pruning for
+  // inter blocks. It enables further tx type mode pruning based on ML model for
+  // mode evaluation and disables tx type mode pruning for winner mode
+  // processing.
+  int enable_winner_mode_tx_type_pruning;
 } TX_TYPE_SEARCH;
 
 enum {