Refactor reference frame pruning logic in tpl

The reference frame removal logic of tpl and encoding
stage are unified.

Change-Id: If735e307a9f0d6327224ba513a61b05b6338bfc4
diff --git a/av1/encoder/encode_strategy.c b/av1/encoder/encode_strategy.c
index ce785f3..cb092ea 100644
--- a/av1/encoder/encode_strategy.c
+++ b/av1/encoder/encode_strategy.c
@@ -210,43 +210,6 @@
   frame_params->error_resilient_mode |= frame_params->frame_type == S_FRAME;
 }
 
-static int get_ref_frame_flags(const AV1_COMP *const cpi) {
-  static const MV_REFERENCE_FRAME
-      ref_frame_priority_order[INTER_REFS_PER_FRAME] = {
-        LAST_FRAME,    ALTREF_FRAME, BWDREF_FRAME, GOLDEN_FRAME,
-        ALTREF2_FRAME, LAST2_FRAME,  LAST3_FRAME,
-      };
-  const AV1_COMMON *const cm = &cpi->common;
-  const RefCntBuffer *ref_frames[INTER_REFS_PER_FRAME];
-  for (int i = 0; i < INTER_REFS_PER_FRAME; ++i) {
-    ref_frames[i] = get_ref_frame_buf(cm, ref_frame_priority_order[i]);
-  }
-
-  // cpi->ext_ref_frame_flags allows certain reference types to be disabled
-  // by the external interface.  These are set by av1_apply_encoding_flags().
-  // Start with what the external interface allows, then suppress any reference
-  // types which we have found to be duplicates.
-  int flags = cpi->ext_ref_frame_flags;
-
-  for (int i = 1; i < INTER_REFS_PER_FRAME; ++i) {
-    const RefCntBuffer *const this_ref = ref_frames[i];
-    // If this_ref has appeared before, mark the corresponding ref frame as
-    // invalid. For nonrd mode, only disable GOLDEN_FRAME if it's the same
-    // as LAST_FRAME or ALTREF_FRAME (if ALTREF is being used in nonrd).
-    int index = (cpi->sf.rt_sf.use_nonrd_pick_mode &&
-                 ref_frame_priority_order[i] == GOLDEN_FRAME)
-                    ? (1 + cpi->sf.rt_sf.use_nonrd_altref_frame)
-                    : i;
-    for (int j = 0; j < index; ++j) {
-      if (this_ref == ref_frames[j]) {
-        flags &= ~(1 << (ref_frame_priority_order[i] - 1));
-        break;
-      }
-    }
-  }
-  return flags;
-}
-
 static int get_current_frame_ref_type(
     const AV1_COMP *const cpi, const EncodeFrameParams *const frame_params) {
   // We choose the reference "type" of this frame from the flags which indicate
@@ -1275,6 +1238,9 @@
                                force_refresh_all);
 
   if (!is_stat_generation_stage(cpi)) {
+    const RefCntBuffer *ref_frames[INTER_REFS_PER_FRAME];
+    const YV12_BUFFER_CONFIG *ref_frame_buf[INTER_REFS_PER_FRAME];
+
     if (!cpi->ext_refresh_frame_flags_pending) {
       av1_get_ref_frames(cpi, &cpi->ref_buffer_stack);
     } else if (cpi->svc.external_ref_frame_config) {
@@ -1282,8 +1248,13 @@
         cm->remapped_ref_idx[i] = cpi->svc.ref_idx[i];
     }
 
+    // Get the reference frames
+    for (int i = 0; i < INTER_REFS_PER_FRAME; ++i) {
+      ref_frames[i] = get_ref_frame_buf(cm, ref_frame_priority_order[i]);
+      ref_frame_buf[i] = ref_frames[i] != NULL ? &ref_frames[i]->buf : NULL;
+    }
     // Work out which reference frame slots may be used.
-    frame_params.ref_frame_flags = get_ref_frame_flags(cpi);
+    frame_params.ref_frame_flags = get_ref_frame_flags(cpi, ref_frame_buf);
 
     frame_params.primary_ref_frame =
         choose_primary_ref_frame(cpi, &frame_params);
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index c87ad6e..907dc3f 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4814,39 +4814,6 @@
   }
 }
 
-// Enforce the number of references for each arbitrary frame based on user
-// options and speed.
-static AOM_INLINE void enforce_max_ref_frames(AV1_COMP *cpi) {
-  MV_REFERENCE_FRAME ref_frame;
-  int total_valid_refs = 0;
-  for (ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME; ++ref_frame) {
-    if (cpi->ref_frame_flags & av1_ref_frame_flag_list[ref_frame]) {
-      total_valid_refs++;
-    }
-  }
-
-  const int max_allowed_refs = get_max_allowed_ref_frames(cpi);
-
-  for (int i = 0; i < 4 && total_valid_refs > max_allowed_refs; ++i) {
-    const MV_REFERENCE_FRAME ref_frame_to_disable = disable_order[i];
-
-    if (!(cpi->ref_frame_flags &
-          av1_ref_frame_flag_list[ref_frame_to_disable])) {
-      continue;
-    }
-
-    switch (ref_frame_to_disable) {
-      case LAST3_FRAME: cpi->ref_frame_flags &= ~AOM_LAST3_FLAG; break;
-      case LAST2_FRAME: cpi->ref_frame_flags &= ~AOM_LAST2_FLAG; break;
-      case ALTREF2_FRAME: cpi->ref_frame_flags &= ~AOM_ALT2_FLAG; break;
-      case GOLDEN_FRAME: cpi->ref_frame_flags &= ~AOM_GOLD_FLAG; break;
-      default: assert(0);
-    }
-    --total_valid_refs;
-  }
-  assert(total_valid_refs <= max_allowed_refs);
-}
-
 static INLINE int av1_refs_are_one_sided(const AV1_COMMON *cm) {
   assert(!frame_is_intra_only(cm));
 
@@ -5581,7 +5548,7 @@
   }
 
   av1_setup_frame_buf_refs(cm);
-  enforce_max_ref_frames(cpi);
+  enforce_max_ref_frames(cpi, &cpi->ref_frame_flags);
   set_rel_frame_dist(cpi);
   av1_setup_frame_sign_bias(cm);
 
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 7290364..b2b8638 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -1516,6 +1516,73 @@
                 cpi->oxcf.max_reference_frames);
 }
 
+static const MV_REFERENCE_FRAME
+    ref_frame_priority_order[INTER_REFS_PER_FRAME] = {
+      LAST_FRAME,    ALTREF_FRAME, BWDREF_FRAME, GOLDEN_FRAME,
+      ALTREF2_FRAME, LAST2_FRAME,  LAST3_FRAME,
+    };
+
+static INLINE int get_ref_frame_flags(const AV1_COMP *const cpi,
+                                      const YV12_BUFFER_CONFIG **ref_frames) {
+  // cpi->ext_ref_frame_flags allows certain reference types to be disabled
+  // by the external interface.  These are set by av1_apply_encoding_flags().
+  // Start with what the external interface allows, then suppress any reference
+  // types which we have found to be duplicates.
+  int flags = cpi->ext_ref_frame_flags;
+
+  for (int i = 1; i < INTER_REFS_PER_FRAME; ++i) {
+    const YV12_BUFFER_CONFIG *const this_ref = ref_frames[i];
+    // If this_ref has appeared before, mark the corresponding ref frame as
+    // invalid. For nonrd mode, only disable GOLDEN_FRAME if it's the same
+    // as LAST_FRAME or ALTREF_FRAME (if ALTREF is being used in nonrd).
+    int index = (cpi->sf.rt_sf.use_nonrd_pick_mode &&
+                 ref_frame_priority_order[i] == GOLDEN_FRAME)
+                    ? (1 + cpi->sf.rt_sf.use_nonrd_altref_frame)
+                    : i;
+    for (int j = 0; j < index; ++j) {
+      if (this_ref == ref_frames[j]) {
+        flags &= ~(1 << (ref_frame_priority_order[i] - 1));
+        break;
+      }
+    }
+  }
+  return flags;
+}
+
+// Enforce the number of references for each arbitrary frame based on user
+// options and speed.
+static AOM_INLINE void enforce_max_ref_frames(AV1_COMP *cpi,
+                                              int *ref_frame_flags) {
+  MV_REFERENCE_FRAME ref_frame;
+  int total_valid_refs = 0;
+
+  for (ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME; ++ref_frame) {
+    if (*ref_frame_flags & av1_ref_frame_flag_list[ref_frame]) {
+      total_valid_refs++;
+    }
+  }
+
+  const int max_allowed_refs = get_max_allowed_ref_frames(cpi);
+
+  for (int i = 0; i < 4 && total_valid_refs > max_allowed_refs; ++i) {
+    const MV_REFERENCE_FRAME ref_frame_to_disable = disable_order[i];
+
+    if (!(*ref_frame_flags & av1_ref_frame_flag_list[ref_frame_to_disable])) {
+      continue;
+    }
+
+    switch (ref_frame_to_disable) {
+      case LAST3_FRAME: *ref_frame_flags &= ~AOM_LAST3_FLAG; break;
+      case LAST2_FRAME: *ref_frame_flags &= ~AOM_LAST2_FLAG; break;
+      case ALTREF2_FRAME: *ref_frame_flags &= ~AOM_ALT2_FLAG; break;
+      case GOLDEN_FRAME: *ref_frame_flags &= ~AOM_GOLD_FLAG; break;
+      default: assert(0);
+    }
+    --total_valid_refs;
+  }
+  assert(total_valid_refs <= max_allowed_refs);
+}
+
 // Returns a Sequence Header OBU stored in an aom_fixed_buf_t, or NULL upon
 // failure. When a non-NULL aom_fixed_buf_t pointer is returned by this
 // function, the memory must be freed by the caller. Both the buf member of the
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 829f519..7600f09 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -671,10 +671,10 @@
   const YV12_BUFFER_CONFIG *this_frame = tpl_frame->gf_picture;
   const YV12_BUFFER_CONFIG *ref_frame[7] = { NULL, NULL, NULL, NULL,
                                              NULL, NULL, NULL };
+  const YV12_BUFFER_CONFIG *ref_frames_ordered[INTER_REFS_PER_FRAME];
   unsigned int ref_frame_display_index[7];
   MV_REFERENCE_FRAME ref[2] = { LAST_FRAME, INTRA_FRAME };
-  const int max_allowed_refs = get_max_allowed_ref_frames(cpi);
-  int total_valid_refs = 0;
+  int ref_frame_flags;
   const YV12_BUFFER_CONFIG *src_frame[7] = { NULL, NULL, NULL, NULL,
                                              NULL, NULL, NULL };
 
@@ -706,17 +706,21 @@
     src_frame[idx] = cpi->tpl_frame[tpl_frame->ref_map_index[idx]].gf_picture;
   }
 
-  // Remove duplicate frames
-  for (int idx1 = 0; idx1 < INTER_REFS_PER_FRAME; ++idx1) {
-    for (int idx2 = idx1 + 1; idx2 < INTER_REFS_PER_FRAME; ++idx2) {
-      if (ref_frame[idx1] == ref_frame[idx2]) {
-        ref_frame[idx2] = NULL;
-      }
-    }
+  // Store the reference frames based on priority order
+  for (int i = 0; i < INTER_REFS_PER_FRAME; ++i) {
+    ref_frames_ordered[i] = ref_frame[ref_frame_priority_order[i] - 1];
   }
 
+  // Work out which reference frame slots may be used.
+  ref_frame_flags = get_ref_frame_flags(cpi, ref_frames_ordered);
+
+  enforce_max_ref_frames(cpi, &ref_frame_flags);
+
+  // Prune reference frames
   for (idx = 0; idx < INTER_REFS_PER_FRAME; ++idx) {
-    if (ref_frame[idx] != NULL) total_valid_refs++;
+    if ((ref_frame_flags & (1 << idx)) == 0) {
+      ref_frame[idx] = NULL;
+    }
   }
 
   // Skip motion estimation w.r.t. reference frames which are not
@@ -729,13 +733,6 @@
     }
   }
 
-  // Skip reference frames based on user options and speed.
-  for (idx = 0; idx < 4 && total_valid_refs > max_allowed_refs; ++idx) {
-    const MV_REFERENCE_FRAME ref_frame_to_disable = disable_order[idx];
-    ref_frame[ref_frame_to_disable - 1] = NULL;
-    total_valid_refs--;
-  }
-
   // Make a temporary mbmi for tpl model
   MB_MODE_INFO mbmi;
   memset(&mbmi, 0, sizeof(mbmi));