CWG-E171 Intra mode search adjustment

1. Allow more full transform searches for intra modes
by relaxing pruning:
(1). Allow 6 instead of 4 best modes.
(2). Remove the best_model_rd constraint

2. Prune tx partition search

Record the rdcost of the none partition for each intra prediction mode.
Keep the 4 lowest and terminate further tx partition search if the none
tx partition rdcost is already larger than the 4th best rdcost.

STATS_CHANGED

Change-Id: Iab9875469a5a82c9a5d97d4e82bb415491fd956d
diff --git a/av1/common/enums.h b/av1/common/enums.h
index d3617d9..0935a00 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -916,7 +916,8 @@
 } UENUM1BYTE(CFL_TYPE);
 
 // Number of top model rd to store for pruning y modes in intra mode decision
-#define TOP_INTRA_MODEL_COUNT 4
+#define TOP_INTRA_MODEL_COUNT 6
+#define TOP_TX_PART_COUNT 4
 // Total number of luma intra prediction modes (include both directional and
 // non-directional modes)
 #define LUMA_MODE_COUNT 61
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 37a2e18..5711070 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -1905,6 +1905,10 @@
    */
   int palette_pixels;
 #endif  // CONFIG_SCC_DETERMINATION
+  /*! \brief Whether to prune current transform partition search. */
+  int prune_tx_partition;
+  /*! \brief Keep records of top rdcosts of transform partition search. */
+  int64_t top_tx_part_rd[TOP_TX_PART_COUNT];
 } MACROBLOCK;
 #undef SINGLE_REF_MODES
 
diff --git a/av1/encoder/intra_mode_search.c b/av1/encoder/intra_mode_search.c
index 6ad0c6d..449dec4 100644
--- a/av1/encoder/intra_mode_search.c
+++ b/av1/encoder/intra_mode_search.c
@@ -82,6 +82,7 @@
     if (model_intra_yrd_and_prune(cpi, x, bsize, mode_cost, best_model_rd)) {
       continue;
     }
+    x->prune_tx_partition = 0;
     av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
                                       *best_rd);
     if (tokenonly_rd_stats.rate == INT_MAX) continue;
@@ -225,7 +226,6 @@
  */
 int prune_intra_y_mode(int64_t this_model_rd, int64_t *best_model_rd,
                        int64_t top_intra_model_rd[]) {
-  const double thresh_best = 1.50;
   const double thresh_top = 1.00;
   for (int i = 0; i < TOP_INTRA_MODEL_COUNT; i++) {
     if (this_model_rd < top_intra_model_rd[i]) {
@@ -241,9 +241,6 @@
           thresh_top * top_intra_model_rd[TOP_INTRA_MODEL_COUNT - 1])
     return 1;
 
-  if (this_model_rd != INT64_MAX &&
-      this_model_rd > thresh_best * (*best_model_rd))
-    return 1;
   if (this_model_rd < *best_model_rd) *best_model_rd = this_model_rd;
   return 0;
 }
@@ -1093,6 +1090,7 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
   RD_STATS rd_stats;
+  x->prune_tx_partition = 0;
   // In order to improve txfm search avoid rd based breakouts during winner
   // mode evaluation. Hence passing ref_best_rd as a maximum value
   av1_pick_uniform_tx_size_type_yrd(cpi, x, &rd_stats, bsize, INT64_MAX);
@@ -1171,6 +1169,7 @@
   for (FILTER_INTRA_MODE fi_mode = FILTER_DC_PRED; fi_mode < FILTER_INTRA_MODES;
        ++fi_mode) {
     mbmi->filter_intra_mode_info.filter_intra_mode = fi_mode;
+    x->prune_tx_partition = 0;
     av1_pick_uniform_tx_size_type_yrd(cpi, x, &rd_stats_y_fi, bsize, best_rd);
     if (rd_stats_y_fi.rate == INT_MAX) continue;
     const int this_rate_tmp =
@@ -1374,6 +1373,7 @@
   )
     return INT64_MAX;
   av1_init_rd_stats(rd_stats_y);
+  x->prune_tx_partition = 0;
   av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, best_rd);
 
   // Pick filter intra modes.
@@ -1602,6 +1602,14 @@
   if (xd->lossless[mbmi->segment_id]) {
     dpcm_fsc_loop = 2;
   }
+  int64_t top_intra_model_rd[TOP_INTRA_MODEL_COUNT];
+  for (int i = 0; i < TOP_INTRA_MODEL_COUNT; i++) {
+    top_intra_model_rd[i] = INT64_MAX;
+  }
+  x->prune_tx_partition = 1;
+  for (int i = 0; i < TOP_TX_PART_COUNT; i++) {
+    x->top_tx_part_rd[i] = INT64_MAX;
+  }
   for (int dpcm_fsc_index = 0; dpcm_fsc_index < dpcm_fsc_loop;
        dpcm_fsc_index++) {
     mbmi->use_dpcm_y = dpcm_fsc_index;
@@ -1741,19 +1749,20 @@
 #if CONFIG_AIMC
         mode_costs += mrl_idx_cost;
 #endif  // CONFIG_AIMC
-        if (model_intra_yrd_and_prune(cpi, x, bsize,
+        int64_t this_model_rd;
+        this_model_rd = intra_model_yrd(cpi, x, bsize,
 #if CONFIG_AIMC
-                                      mode_costs,
+                                        mode_costs);
 #else
-                                    mode_costs[mbmi->mode] + mrl_idx_cost,
-#endif
-                                      best_model_rd)
+                                      mode_costs[mbmi->mode] + mrl_idx_cost);
+#endif  // CONFIG_AIMC
+
+        if (prune_intra_y_mode(this_model_rd, best_model_rd, top_intra_model_rd)
 #if CONFIG_LOSSLESS_DPCM
             && (!xd->lossless[mbmi->segment_id] || mbmi->use_dpcm_y == 0)
 #endif  // CONFIG_LOSSLESS_DPCM
-        ) {
+        )
           continue;
-        }
         av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
                                           *best_rd);
         if (tokenonly_rd_stats.rate == INT_MAX) continue;
@@ -1909,6 +1918,10 @@
   for (int i = 0; i < TOP_INTRA_MODEL_COUNT; i++) {
     top_intra_model_rd[i] = INT64_MAX;
   }
+  x->prune_tx_partition = 1;
+  for (int i = 0; i < TOP_TX_PART_COUNT; i++) {
+    x->top_tx_part_rd[i] = INT64_MAX;
+  }
   uint8_t enable_mrls_flag = cpi->common.seq_params.enable_mrls;
 #if CONFIG_LOSSLESS_DPCM
   int dpcm_loop_num = 1;
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index 874b7cf..96d3ecf 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -4063,6 +4063,20 @@
       *rd_stats = this_rd_stats;
     }
     if (cur_tx_size == TX_4X4) break;
+    if (x->prune_tx_partition && type == 0) {
+      for (int i = 0; i < TOP_TX_PART_COUNT; i++) {
+        if (cur_rd < x->top_tx_part_rd[i]) {
+          for (int j = TOP_TX_PART_COUNT - 1; j > i; j--) {
+            x->top_tx_part_rd[j] = x->top_tx_part_rd[j - 1];
+          }
+          x->top_tx_part_rd[i] = cur_rd;
+          break;
+        }
+      }
+      if (x->top_tx_part_rd[TOP_TX_PART_COUNT - 1] != INT64_MAX &&
+          cur_rd > x->top_tx_part_rd[TOP_TX_PART_COUNT - 1])
+        break;
+    }
   }
 
   if (rd_stats->rate != INT_MAX) {