Tpl model based reference frame selection

Use tpl model stats to validate the hypothesis on reference frame
selection. This improves the compression performance. Tested in vbr
mode over 150 frames.

In speed 1,
      avg PSNR    overall PSNR    SSIM
low    -0.186	    -0.187	-0.155
mid    -0.026	    -0.108	-0.063
ugc    -0.104  	    -0.091	-0.124

In speed 2,
      avg PSNR    overall PSNR    SSIM
low    -1.369	    -1.422	-1.809
mid    -1.394	    -1.433	-1.761
ugc    -1.001	    -0.979	-1.063

The average CPU cycles in speed 2 increased by 5%.

STATS_CHANGED

Change-Id: I350ebabaf3974ae7c47ed9b4272485ebce85ac32
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 3f93a8e..a8cc8ff 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -472,6 +472,8 @@
   // The type of mv cost used during motion search
   MV_COST_TYPE mv_cost_type;
 
+  int search_ref_frame[REF_FRAMES];
+
 #if CONFIG_AV1_HIGHBITDEPTH
   void (*fwd_txfm4x4)(const int16_t *input, tran_low_t *output, int stride);
   void (*inv_txfm_add)(const tran_low_t *input, uint8_t *dest, int stride,
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 9436c37..687d688 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4388,6 +4388,75 @@
 #endif  // CONFIG_INTERNAL_STATS
 }
 
+#if !CONFIG_REALTIME_ONLY
+static void init_ref_frame_space(AV1_COMP *cpi, ThreadData *td, int mi_row,
+                                 int mi_col) {
+  const AV1_COMMON *cm = &cpi->common;
+  MACROBLOCK *x = &td->mb;
+  const int frame_idx = cpi->gf_group.index;
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[frame_idx];
+
+  memset(x->search_ref_frame, 0, sizeof(x->search_ref_frame));
+
+  if (tpl_frame->is_valid == 0) return;
+  if (!is_frame_tpl_eligible(cpi)) return;
+  if (frame_idx >= MAX_LAG_BUFFERS) return;
+  if (cpi->oxcf.superres_mode != SUPERRES_NONE) return;
+  if (cpi->oxcf.aq_mode != NO_AQ) return;
+
+  TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
+  const int tpl_stride = tpl_frame->stride;
+  int64_t inter_cost[INTER_REFS_PER_FRAME] = { 0 };
+  const int step = 1 << cpi->tpl_stats_block_mis_log2;
+
+  const BLOCK_SIZE sb_size = cm->seq_params.sb_size;
+  const int mi_row_end = AOMMIN(mi_size_high[sb_size] + mi_row, cm->mi_rows);
+  const int mi_col_end = AOMMIN(mi_size_wide[sb_size] + mi_col, cm->mi_cols);
+
+  for (int row = mi_row; row < mi_row_end; row += step) {
+    for (int col = mi_col; col < mi_col_end; col += step) {
+      TplDepStats *this_stats =
+          &tpl_stats[av1_tpl_ptr_pos(cpi, row, col, tpl_stride)];
+      for (int rf_idx = 0; rf_idx < INTER_REFS_PER_FRAME; ++rf_idx)
+        inter_cost[rf_idx] += this_stats->pred_error[rf_idx];
+    }
+  }
+
+  int rank_index[INTER_REFS_PER_FRAME - 1] = { 0 };
+
+  for (int idx = 0; idx < INTER_REFS_PER_FRAME - 1; ++idx) {
+    rank_index[idx] = idx + 1;
+    for (int i = idx; i > 0; --i) {
+      if (inter_cost[rank_index[i - 1]] > inter_cost[rank_index[i]]) {
+        const int tmp = rank_index[i - 1];
+        rank_index[i - 1] = rank_index[i];
+        rank_index[i] = tmp;
+      }
+    }
+  }
+
+  const int is_overlay = cpi->gf_group.update_type[frame_idx] == OVERLAY_UPDATE;
+  x->search_ref_frame[INTRA_FRAME] = 1;
+  x->search_ref_frame[LAST_FRAME] = 1;
+
+  int cutoff_ref = 0;
+
+  for (int idx = 0; idx < INTER_REFS_PER_FRAME - 1; ++idx) {
+    x->search_ref_frame[rank_index[idx] + LAST_FRAME] = 1;
+    if (idx > 2 && !is_overlay) {
+      // If the predictive coding gains are smaller than the previous more
+      // relevant frame over certain amount, discard this frame.
+      if (labs(inter_cost[rank_index[idx]]) <
+              labs(inter_cost[rank_index[idx - 1]]) / 8 ||
+          inter_cost[rank_index[idx]] == 0)
+        cutoff_ref = 1;
+
+      if (cutoff_ref) x->search_ref_frame[rank_index[idx] + LAST_FRAME] = 0;
+    }
+  }
+}
+#endif  // !CONFIG_REALTIME_ONLY
+
 // This function initializes the stats for encode_rd_sb.
 static INLINE void init_encode_rd_sb(AV1_COMP *cpi, ThreadData *td,
                                      const TileDataEnc *tile_data,
@@ -4410,6 +4479,7 @@
   }
 
 #if !CONFIG_REALTIME_ONLY
+  init_ref_frame_space(cpi, td, mi_row, mi_col);
   x->sb_energy_level = 0;
   x->cnn_output_valid = 0;
   if (gather_tpl_data) {
@@ -5202,7 +5272,7 @@
         ref_buf[frame]->y_crop_height == cpi->source->y_crop_height &&
         do_gm_search_logic(&cpi->sf, frame) &&
         !prune_ref_by_selective_ref_frame(
-            cpi, ref_frame, cm->cur_frame->ref_display_order_hint,
+            cpi, NULL, ref_frame, cm->cur_frame->ref_display_order_hint,
             cm->current_frame.display_order_hint) &&
         !(cpi->sf.gm_sf.selective_ref_gm && skip_gm_frame(cm, frame))) {
       assert(ref_buf[frame] != NULL);
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 86fd720..7b6e685 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -163,6 +163,7 @@
   int64_t mc_dep_dist;
   int_mv mv[INTER_REFS_PER_FRAME];
   int ref_frame_index;
+  int64_t pred_error[INTER_REFS_PER_FRAME];
 #if !USE_TPL_CLASSIC_MODEL
   int64_t mc_count;
   int64_t mc_saved;
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 994b9cf..4490591 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3557,7 +3557,7 @@
     }
   }
 
-  if (prune_ref_by_selective_ref_frame(cpi, ref_frame,
+  if (prune_ref_by_selective_ref_frame(cpi, x, ref_frame,
                                        cm->cur_frame->ref_display_order_hint,
                                        cm->current_frame.display_order_hint))
     return 1;
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 3a59ea4..22f83ff 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -159,6 +159,8 @@
                             const unsigned int frame_display_order_hint,
                             const int *ref_frame_list) {
   for (int i = 0; i < 2; i++) {
+    if (ref_frame_list[i] == NONE_FRAME) continue;
+
     if (ref_frame[0] == ref_frame_list[i] ||
         ref_frame[1] == ref_frame_list[i]) {
       if (av1_encoder_get_relative_dist(
@@ -172,10 +174,12 @@
 }
 
 static INLINE int prune_ref_by_selective_ref_frame(
-    const AV1_COMP *const cpi, const MV_REFERENCE_FRAME *const ref_frame,
+    const AV1_COMP *const cpi, const MACROBLOCK *const x,
+    const MV_REFERENCE_FRAME *const ref_frame,
     const unsigned int *const ref_display_order_hint,
     const unsigned int cur_frame_display_order_hint) {
   const SPEED_FEATURES *const sf = &cpi->sf;
+
   if (sf->inter_sf.selective_ref_frame) {
     const AV1_COMMON *const cm = &cpi->common;
     const OrderHintInfo *const order_hint_info =
@@ -183,7 +187,13 @@
     const int comp_pred = ref_frame[1] > INTRA_FRAME;
     if (sf->inter_sf.selective_ref_frame >= 2 ||
         (sf->inter_sf.selective_ref_frame == 1 && comp_pred)) {
-      const int ref_frame_list[2] = { LAST3_FRAME, LAST2_FRAME };
+      int ref_frame_list[2] = { LAST3_FRAME, LAST2_FRAME };
+
+      if (x != NULL) {
+        if (x->search_ref_frame[LAST3_FRAME]) ref_frame_list[0] = NONE_FRAME;
+        if (x->search_ref_frame[LAST2_FRAME]) ref_frame_list[1] = NONE_FRAME;
+      }
+
       if (prune_ref(ref_frame, order_hint_info, ref_display_order_hint,
                     ref_display_order_hint[GOLDEN_FRAME - LAST_FRAME],
                     ref_frame_list))
@@ -210,9 +220,16 @@
     }
 
     if (sf->inter_sf.selective_ref_frame >= 3) {
-      static const int ref_frame_list[2] = { ALTREF2_FRAME, BWDREF_FRAME };
+      int ref_frame_list[2] = { ALTREF2_FRAME, BWDREF_FRAME };
+
+      if (x != NULL) {
+        if (x->search_ref_frame[ALTREF2_FRAME]) ref_frame_list[0] = NONE_FRAME;
+        if (x->search_ref_frame[BWDREF_FRAME]) ref_frame_list[1] = NONE_FRAME;
+      }
+
       if (prune_ref(ref_frame, order_hint_info, ref_display_order_hint,
-                    cur_frame_display_order_hint, ref_frame_list))
+                    ref_display_order_hint[LAST_FRAME - LAST_FRAME],
+                    ref_frame_list))
         return 1;
     }
 
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 3c244f7..7917213 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -382,6 +382,9 @@
     if (inter_cost < best_inter_cost) {
       memcpy(best_coeff, coeff, sizeof(best_coeff));
       best_rf_idx = rf_idx;
+
+      if (rf_idx == 0) tpl_stats->pred_error[rf_idx] = inter_cost;
+
       best_inter_cost = inter_cost;
       best_mv.as_int = x->best_mv.as_int;
       if (best_inter_cost < best_intra_cost) {
@@ -392,6 +395,11 @@
     }
   }
 
+  if (best_rf_idx >= 0) {
+    tpl_stats->pred_error[best_rf_idx] =
+        best_inter_cost - tpl_stats->pred_error[0];
+  }
+
   if (best_inter_cost < INT64_MAX) {
     uint16_t eob;
     get_quantize_error(x, 0, best_coeff, qcoeff, dqcoeff, tx_size, &eob,
@@ -657,6 +665,8 @@
       tpl_ptr->srcrf_rate = srcrf_rate;
       tpl_ptr->recrf_rate = recrf_rate;
       memcpy(tpl_ptr->mv, src_stats->mv, sizeof(tpl_ptr->mv));
+      memcpy(tpl_ptr->pred_error, src_stats->pred_error,
+             sizeof(tpl_ptr->pred_error));
       tpl_ptr->ref_frame_index = src_stats->ref_frame_index;
       ++tpl_ptr;
     }
@@ -672,8 +682,6 @@
   const YV12_BUFFER_CONFIG *ref_frame[7] = { NULL, NULL, NULL, NULL,
                                              NULL, NULL, NULL };
   const YV12_BUFFER_CONFIG *ref_frames_ordered[INTER_REFS_PER_FRAME];
-  unsigned int ref_frame_display_index[7];
-  MV_REFERENCE_FRAME ref[2] = { LAST_FRAME, INTRA_FRAME };
   int ref_frame_flags;
   const YV12_BUFFER_CONFIG *src_frame[7] = { NULL, NULL, NULL, NULL,
                                              NULL, NULL, NULL };
@@ -700,9 +708,7 @@
   xd->cur_buf = this_frame;
 
   for (idx = 0; idx < INTER_REFS_PER_FRAME; ++idx) {
-    TplDepFrame *tpl_ref_frame = &cpi->tpl_frame[tpl_frame->ref_map_index[idx]];
     ref_frame[idx] = cpi->tpl_frame[tpl_frame->ref_map_index[idx]].rec_picture;
-    ref_frame_display_index[idx] = tpl_ref_frame->frame_display_index;
     src_frame[idx] = cpi->tpl_frame[tpl_frame->ref_map_index[idx]].gf_picture;
   }
 
@@ -723,16 +729,6 @@
     }
   }
 
-  // Skip motion estimation w.r.t. reference frames which are not
-  // considered in RD search, using "selective_ref_frame" speed feature
-  for (idx = 0; idx < INTER_REFS_PER_FRAME; ++idx) {
-    ref[0] = idx + 1;
-    if (prune_ref_by_selective_ref_frame(cpi, ref, ref_frame_display_index,
-                                         tpl_frame->frame_display_index)) {
-      ref_frame[idx] = NULL;
-    }
-  }
-
   // Make a temporary mbmi for tpl model
   MB_MODE_INFO mbmi;
   memset(&mbmi, 0, sizeof(mbmi));