Use surf-fit model for tx size/type pruning

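Use the better-fitting surface-fit (surf-fit) model for transform
size/type search pruning, and tighten the pruning thresholds to match.
Also includes some refactoring: the model_rd_with_* helpers now take
num_samples from the caller and accept a const MACROBLOCK.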

Encoder is about 2% faster, with about 0.02% loss (PSNR/SSIM).

STATS_CHANGED

Change-Id: I328bd13ad1aef6513690c13e605fe78212a2e32e
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 21c3af5..b423c7e 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -92,6 +92,20 @@
     int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
     int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
     int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
+static void model_rd_with_dnn(const AV1_COMP *const cpi,
+                              const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
+                              int plane, int64_t sse, int num_samples,
+                              int *rate, int64_t *dist);
+static void model_rd_with_curvfit(const AV1_COMP *const cpi,
+                                  const MACROBLOCK *const x,
+                                  BLOCK_SIZE plane_bsize, int plane,
+                                  int64_t sse, int num_samples, int *rate,
+                                  int64_t *dist);
+static void model_rd_with_surffit(const AV1_COMP *const cpi,
+                                  const MACROBLOCK *const x,
+                                  BLOCK_SIZE plane_bsize, int plane,
+                                  int64_t sse, int num_samples, int *rate,
+                                  int64_t *dist);
 
 typedef enum {
   MODELRD_LEGACY,
@@ -113,7 +127,7 @@
 // 3: DNN regression model
 // 4: Full rd model
 #define MODELRD_TYPE_INTERP_FILTER 1
-#define MODELRD_TYPE_TX_SEARCH_PRUNE 1
+#define MODELRD_TYPE_TX_SEARCH_PRUNE 2
 
 #define DUAL_FILTER_SET_SIZE (SWITCHABLE_FILTERS * SWITCHABLE_FILTERS)
 static const InterpFilters filter_sets[DUAL_FILTER_SET_SIZE] = {
@@ -2508,22 +2522,22 @@
 #endif  // CONFIG_COLLECT_RD_STATS >= 2
 #endif  // CONFIG_COLLECT_RD_STATS
 
-static void model_rd_with_dnn(const AV1_COMP *const cpi, MACROBLOCK *const x,
-                              BLOCK_SIZE plane_bsize, int plane, int64_t sse,
+static void model_rd_with_dnn(const AV1_COMP *const cpi,
+                              const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
+                              int plane, int64_t sse, int num_samples,
                               int *rate, int64_t *dist) {
   const MACROBLOCKD *const xd = &x->e_mbd;
   const struct macroblockd_plane *const pd = &xd->plane[plane];
   const int log_numpels = num_pels_log2_lookup[plane_bsize];
 
-  int bw, bh;
-  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
-                     &bh);
-  const int num_samples = bw * bh;
   const int dequant_shift =
       (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3;
   const int q_step = AOMMAX(pd->dequant_Q3[1] >> dequant_shift, 1);
 
   const struct macroblock_plane *const p = &x->plane[plane];
+  int bw, bh;
+  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
+                     &bh);
   const int src_stride = p->src.stride;
   const uint8_t *const src = p->src.buf;
   const int dst_stride = pd->dst.stride;
@@ -2542,8 +2556,8 @@
   if (plane) {
     int model_rate;
     int64_t model_dist;
-    model_rd_from_sse(cpi, x, plane_bsize, plane, sse, &model_rate,
-                      &model_dist);
+    model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, num_samples,
+                          &model_rate, &model_dist);
     if (rate) *rate = model_rate;
     if (dist) *dist = model_dist;
     return;
@@ -2626,9 +2640,9 @@
 
     if (x->skip_chroma_rd && plane) continue;
 
-    int bw, bh;
     const struct macroblock_plane *const p = &x->plane[plane];
     const int shift = (xd->bd - 8);
+    int bw, bh;
     get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
                        &bw, &bh);
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
@@ -2640,7 +2654,7 @@
     }
     sse = ROUND_POWER_OF_TWO(sse, shift * 2);
 
-    model_rd_with_dnn(cpi, x, plane_bsize, plane, sse, &rate, &dist);
+    model_rd_with_dnn(cpi, x, plane_bsize, plane, sse, bw * bh, &rate, &dist);
 
     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
 
@@ -2663,19 +2677,17 @@
 // Fits a surface for rate and distortion using as features:
 // log2(sse_norm + 1) and log2(sse_norm/qstep^2)
 static void model_rd_with_surffit(const AV1_COMP *const cpi,
-                                  MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
-                                  int plane, int64_t sse, int *rate,
+                                  const MACROBLOCK *const x,
+                                  BLOCK_SIZE plane_bsize, int plane,
+                                  int64_t sse, int num_samples, int *rate,
                                   int64_t *dist) {
   (void)cpi;
+  (void)plane_bsize;
   const MACROBLOCKD *const xd = &x->e_mbd;
   const struct macroblockd_plane *const pd = &xd->plane[plane];
   const int dequant_shift =
       (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3;
   const int qstep = AOMMAX(pd->dequant_Q3[1] >> dequant_shift, 1);
-  int bw, bh;
-  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
-                     &bh);
-  const int num_samples = bw * bh;
   if (sse == 0) {
     if (rate) *rate = 0;
     if (dist) *dist = 0;
@@ -2746,7 +2758,8 @@
     }
     sse = ROUND_POWER_OF_TWO(sse, shift * 2);
 
-    model_rd_with_surffit(cpi, x, plane_bsize, plane, sse, &rate, &dist);
+    model_rd_with_surffit(cpi, x, plane_bsize, plane, sse, bw * bh, &rate,
+                          &dist);
 
     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
 
@@ -2769,19 +2782,17 @@
 // Fits a curve for rate and distortion using as feature:
 // log2(sse_norm/qstep^2)
 static void model_rd_with_curvfit(const AV1_COMP *const cpi,
-                                  MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
-                                  int plane, int64_t sse, int *rate,
+                                  const MACROBLOCK *const x,
+                                  BLOCK_SIZE plane_bsize, int plane,
+                                  int64_t sse, int num_samples, int *rate,
                                   int64_t *dist) {
   (void)cpi;
+  (void)plane_bsize;
   const MACROBLOCKD *const xd = &x->e_mbd;
   const struct macroblockd_plane *const pd = &xd->plane[plane];
   const int dequant_shift =
       (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3;
   const int qstep = AOMMAX(pd->dequant_Q3[1] >> dequant_shift, 1);
-  int bw, bh;
-  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
-                     &bh);
-  const int num_samples = bw * bh;
 
   if (sse == 0) {
     if (rate) *rate = 0;
@@ -2853,7 +2864,8 @@
     }
 
     sse = ROUND_POWER_OF_TWO(sse, shift * 2);
-    model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, &rate, &dist);
+    model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, bw * bh, &rate,
+                          &dist);
 
     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
 
@@ -2917,7 +2929,8 @@
         dist = rd_stats.dist;
       }
     } else {
-      model_rd_from_sse(cpi, x, plane_bsize, plane, sse, &rate, &dist);
+      model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, bw * bh, &rate,
+                            &dist);
     }
 
     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
@@ -5595,7 +5608,8 @@
     // tighter.
     assert(cpi->sf.model_based_prune_tx_search_level >= 0 &&
            cpi->sf.model_based_prune_tx_search_level <= 2);
-    static const int prune_factor_by8[] = { 3, 5 };
+    static const int prune_factor_by8[] = { 2 + MODELRD_TYPE_TX_SEARCH_PRUNE,
+                                            4 + MODELRD_TYPE_TX_SEARCH_PRUNE };
     if (!model_skip &&
         ((model_rd *
           prune_factor_by8[cpi->sf.model_based_prune_tx_search_level - 1]) >>
@@ -7416,6 +7430,7 @@
     sse = ROUND_POWER_OF_TWO(sse, bd_round);
 
     model_rd_from_sse(cpi, x, bsize, 0, sse, &rate, &dist);
+    // model_rd_with_curvfit(cpi, x, bsize, 0, sse, N, &rate, &dist);
     rate += x->wedge_idx_cost[bsize][wedge_index];
     rd = RDCOST(x->rdmult, rate, dist);
 
@@ -7461,6 +7476,7 @@
     sse = ROUND_POWER_OF_TWO(sse, bd_round);
 
     model_rd_from_sse(cpi, x, bsize, 0, sse, &rate, &dist);
+    // model_rd_with_curvfit(cpi, x, bsize, 0, sse, N, &rate, &dist);
     rate += x->wedge_idx_cost[bsize][wedge_index];
     rd = RDCOST(x->rdmult, rate, dist);
 
@@ -7539,6 +7555,7 @@
     sse = ROUND_POWER_OF_TWO(sse, bd_round);
 
     model_rd_from_sse(cpi, x, bsize, 0, sse, &rate, &dist);
+    // model_rd_with_curvfit(cpi, x, bsize, 0, sse, N, &rate, &dist);
     const int64_t rd0 = RDCOST(x->rdmult, rate, dist);
 
     if (rd0 < best_rd) {