Add a new struct that holds sb level info on the encoder

This CL adds SuperBlockEnc struct that is meant to hold the superblock
level information in the rdopt process. For example, CNN-based
partitioning buffers can go in there.

Currently only tpl-based speed feature structs are added to
SuperBlockEnc.

BUG=aomedia:2618

Change-Id: I519258cd0bf1e3dd4707cf8396d67853d94bc7db
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index c88603b..6252200 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -32,6 +32,20 @@
 #define MAX_MC_FLOW_BLK_IN_SB (MAX_SB_SIZE / MC_FLOW_BSIZE_1D)
 #define MAX_WINNER_MODE_COUNT_INTRA 3
 #define MAX_WINNER_MODE_COUNT_INTER 1
+
+// SuperblockEnc stores superblock level information used by the encoder for
+// more efficient encoding.
+typedef struct {
+  // Below are information gathered from tpl_model used to speed up the encoding
+  // process.
+  int tpl_data_count;
+  int64_t tpl_inter_cost[MAX_MC_FLOW_BLK_IN_SB * MAX_MC_FLOW_BLK_IN_SB];
+  int64_t tpl_intra_cost[MAX_MC_FLOW_BLK_IN_SB * MAX_MC_FLOW_BLK_IN_SB];
+  int_mv tpl_mv[MAX_MC_FLOW_BLK_IN_SB * MAX_MC_FLOW_BLK_IN_SB]
+               [INTER_REFS_PER_FRAME];
+  int tpl_stride;
+} SuperBlockEnc;
+
 typedef struct {
   MB_MODE_INFO mbmi;
   RD_STATS rd_cost;
@@ -467,16 +481,12 @@
   // (normal/winner mode)
   unsigned int predict_skip_level;
 
-  // Copy out this SB's TPL block stats.
-  int valid_cost_b;
-  int64_t inter_cost_b[MAX_MC_FLOW_BLK_IN_SB * MAX_MC_FLOW_BLK_IN_SB];
-  int64_t intra_cost_b[MAX_MC_FLOW_BLK_IN_SB * MAX_MC_FLOW_BLK_IN_SB];
-  int_mv mv_b[MAX_MC_FLOW_BLK_IN_SB * MAX_MC_FLOW_BLK_IN_SB]
-             [INTER_REFS_PER_FRAME];
-  int cost_stride;
-
   uint8_t search_ref_frame[REF_FRAMES];
 
+  // The information on a whole superblock level.
+  // TODO(chiyotsai@google.com): Refactor this out of macroblock
+  SuperBlockEnc sb_enc;
+
 #if CONFIG_AV1_HIGHBITDEPTH
   void (*fwd_txfm4x4)(const int16_t *input, tran_low_t *output, int stride);
   void (*inv_txfm_add)(const tran_low_t *input, uint8_t *dest, int stride,
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index a13b745..558e757 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4070,16 +4070,16 @@
   return rdmult;
 }
 
-static int get_tpl_stats_b(AV1_COMP *cpi, BLOCK_SIZE bsize, int mi_row,
-                           int mi_col, int64_t *intra_cost_b,
-                           int64_t *inter_cost_b,
-                           int_mv mv_b[][INTER_REFS_PER_FRAME], int *stride) {
-  if (!cpi->oxcf.enable_tpl_model) return 0;
-  if (cpi->superres_mode != AOM_SUPERRES_NONE) return 0;
-  if (cpi->common.current_frame.frame_type == KEY_FRAME) return 0;
+static void get_tpl_stats_sb(AV1_COMP *cpi, BLOCK_SIZE bsize, int mi_row,
+                             int mi_col, SuperBlockEnc *sb_enc) {
+  sb_enc->tpl_data_count = 0;
+
+  if (!cpi->oxcf.enable_tpl_model) return;
+  if (cpi->superres_mode != AOM_SUPERRES_NONE) return;
+  if (cpi->common.current_frame.frame_type == KEY_FRAME) return;
   const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
   if (update_type == INTNL_OVERLAY_UPDATE || update_type == OVERLAY_UPDATE)
-    return 0;
+    return;
   assert(IMPLIES(cpi->gf_group.size > 0,
                  cpi->gf_group.index < cpi->gf_group.size));
 
@@ -4092,8 +4092,8 @@
   const int mi_wide = mi_size_wide[bsize];
   const int mi_high = mi_size_high[bsize];
 
-  if (tpl_frame->is_valid == 0) return 0;
-  if (gf_group_index >= MAX_TPL_FRAME_IDX) return 0;
+  if (tpl_frame->is_valid == 0) return;
+  if (gf_group_index >= MAX_TPL_FRAME_IDX) return;
 
   int mi_count = 0;
   int count = 0;
@@ -4113,16 +4113,16 @@
 
   // Stride is only based on SB size, and we fill in values for every 16x16
   // block in a SB.
-  *stride = (mi_col_end_sr - mi_col_sr) / step;
+  sb_enc->tpl_stride = (mi_col_end_sr - mi_col_sr) / step;
 
   for (int row = mi_row; row < mi_row + mi_high; row += step) {
     for (int col = mi_col_sr; col < mi_col_end_sr; col += step) {
       // Handle partial SB, so that no invalid values are used later.
       if (row >= cm->mi_params.mi_rows || col >= mi_cols_sr) {
-        inter_cost_b[count] = INT64_MAX;
-        intra_cost_b[count] = INT64_MAX;
+        sb_enc->tpl_inter_cost[count] = INT64_MAX;
+        sb_enc->tpl_intra_cost[count] = INT64_MAX;
         for (int i = 0; i < INTER_REFS_PER_FRAME; ++i) {
-          mv_b[count][i].as_int = INVALID_MV;
+          sb_enc->tpl_mv[count][i].as_int = INVALID_MV;
         }
         count++;
         continue;
@@ -4130,15 +4130,15 @@
 
       TplDepStats *this_stats = &tpl_stats[av1_tpl_ptr_pos(
           row, col, tpl_stride, tpl_data->tpl_stats_block_mis_log2)];
-      inter_cost_b[count] = this_stats->inter_cost;
-      intra_cost_b[count] = this_stats->intra_cost;
-      memcpy(mv_b[count], this_stats->mv, sizeof(this_stats->mv));
+      sb_enc->tpl_inter_cost[count] = this_stats->inter_cost;
+      sb_enc->tpl_intra_cost[count] = this_stats->intra_cost;
+      memcpy(sb_enc->tpl_mv[count], this_stats->mv, sizeof(this_stats->mv));
       mi_count++;
       count++;
     }
   }
 
-  return mi_count;
+  sb_enc->tpl_data_count = mi_count;
 }
 
 // analysis_type 0: Use mc_dep_cost and intra_cost
@@ -4871,10 +4871,9 @@
                      &dummy_rate, &dummy_dist, 1, pc_root);
     av1_free_pc_tree_recursive(pc_root, num_planes, 0, 0);
   } else {
+    SuperBlockEnc *sb_enc = &x->sb_enc;
     // No stats for overlay frames. Exclude key frame.
-    x->valid_cost_b =
-        get_tpl_stats_b(cpi, sb_size, mi_row, mi_col, x->intra_cost_b,
-                        x->inter_cost_b, x->mv_b, &x->cost_stride);
+    get_tpl_stats_sb(cpi, sb_size, mi_row, mi_col, sb_enc);
 
     reset_simple_motion_tree_partition(sms_root, sb_size);
 
@@ -4923,7 +4922,7 @@
                         pc_root_p1, sms_root, NULL, SB_WET_PASS, NULL);
     }
     // Reset to 0 so that it wouldn't be used elsewhere mistakenly.
-    x->valid_cost_b = 0;
+    sb_enc->tpl_data_count = 0;
 #if CONFIG_COLLECT_COMPONENT_TIMING
     end_timing(cpi, rd_pick_partition_time);
 #endif
diff --git a/av1/encoder/motion_search_facade.c b/av1/encoder/motion_search_facade.c
index 0a49bc1..f708984 100644
--- a/av1/encoder/motion_search_facade.c
+++ b/av1/encoder/motion_search_facade.c
@@ -142,7 +142,8 @@
 
   if (!cpi->sf.mv_sf.full_pixel_search_level &&
       mbmi->motion_mode == SIMPLE_TRANSLATION) {
-    if (x->valid_cost_b) {
+    SuperBlockEnc *sb_enc = &x->sb_enc;
+    if (sb_enc->tpl_data_count) {
       const BLOCK_SIZE tpl_bsize = convert_length_to_bsize(MC_FLOW_BSIZE_1D);
       const int tplw = mi_size_wide[tpl_bsize];
       const int tplh = mi_size_high[tpl_bsize];
@@ -152,7 +153,7 @@
       if (nw >= 1 && nh >= 1) {
         const int of_h = mi_row % mi_size_high[cm->seq_params.sb_size];
         const int of_w = mi_col % mi_size_wide[cm->seq_params.sb_size];
-        const int start = of_h / tplh * x->cost_stride + of_w / tplw;
+        const int start = of_h / tplh * sb_enc->tpl_stride + of_w / tplw;
         int valid = 1;
 
         // Assign large weight to start_mv, so it is always tested.
@@ -160,8 +161,8 @@
 
         for (int k = 0; k < nh; k++) {
           for (int l = 0; l < nw; l++) {
-            const int_mv mv =
-                x->mv_b[start + k * x->cost_stride + l][ref - LAST_FRAME];
+            const int_mv mv = sb_enc->tpl_mv[start + k * sb_enc->tpl_stride + l]
+                                            [ref - LAST_FRAME];
             if (mv.as_int == INVALID_MV) {
               valid = 0;
               break;
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index c93e854..0576e59 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4603,8 +4603,10 @@
   if (do_pruning && sf->intra_sf.skip_intra_in_interframe) {
     // Only consider full SB.
     int len = tpl_blocks_in_sb(cm->seq_params.sb_size);
-    if (len == x->valid_cost_b) {
+    SuperBlockEnc *sb_enc = &x->sb_enc;
+    if (sb_enc->tpl_data_count == len) {
       const BLOCK_SIZE tpl_bsize = convert_length_to_bsize(MC_FLOW_BSIZE_1D);
+      const int tpl_stride = sb_enc->tpl_stride;
       const int tplw = mi_size_wide[tpl_bsize];
       const int tplh = mi_size_high[tpl_bsize];
       const int nw = mi_size_wide[bsize] / tplw;
@@ -4612,12 +4614,12 @@
       if (nw >= 1 && nh >= 1) {
         const int of_h = mi_row % mi_size_high[cm->seq_params.sb_size];
         const int of_w = mi_col % mi_size_wide[cm->seq_params.sb_size];
-        const int start = of_h / tplh * x->cost_stride + of_w / tplw;
+        const int start = of_h / tplh * tpl_stride + of_w / tplw;
 
         for (int k = 0; k < nh; k++) {
           for (int l = 0; l < nw; l++) {
-            inter_cost += x->inter_cost_b[start + k * x->cost_stride + l];
-            intra_cost += x->intra_cost_b[start + k * x->cost_stride + l];
+            inter_cost += sb_enc->tpl_inter_cost[start + k * tpl_stride + l];
+            intra_cost += sb_enc->tpl_intra_cost[start + k * tpl_stride + l];
           }
         }
         inter_cost /= nw * nh;