Fix tpl with new subgop refresh rule

Change-Id: I3173174df2ef81b2ddcc4fa1b758af91b62fe6d8
diff --git a/av1/encoder/encode_strategy.c b/av1/encoder/encode_strategy.c
index 9a19b28..3ae25db 100644
--- a/av1/encoder/encode_strategy.c
+++ b/av1/encoder/encode_strategy.c
@@ -273,21 +273,6 @@
   return primary_ref_frame;
 }
 
-static INLINE int get_true_pyr_level(int frame_level, int frame_order,
-                                     int max_layer_depth) {
-  if (frame_order == 0) {
-    // Keyframe case
-    return 1;
-  } else if (frame_level == MAX_ARF_LAYERS) {
-    // Leaves
-    return max_layer_depth;
-  } else if (frame_level == (MAX_ARF_LAYERS + 1)) {
-    // Altrefs
-    return 1;
-  }
-  return frame_level;
-}
-
 // Map the subgop cfg reference list to actual reference buffers. Disable
 // any reference frames that are not listed in the sub gop.
 static void get_gop_cfg_enabled_refs(AV1_COMP *const cpi, int *ref_frame_flags,
@@ -890,12 +875,9 @@
   return INVALID_IDX;
 }
 
-static int get_refresh_idx(const AV1_COMP *const cpi,
-                           const EncodeFrameParams *const frame_params,
-                           int update_arf, int refresh_level) {
-  const int order_offset = frame_params->order_offset;
-  const int cur_frame_disp =
-      cpi->common.current_frame.frame_number + order_offset;
+static int get_refresh_idx(int update_arf, int refresh_level,
+                           int cur_frame_disp,
+                           RefFrameMapPair ref_frame_map_pairs[REF_FRAMES]) {
   int arf_count = 0;
   int oldest_arf_order = INT32_MAX;
   int oldest_arf_idx = -1;
@@ -907,14 +889,11 @@
   int oldest_ref_level_idx = -1;
 
   for (int map_idx = 0; map_idx < REF_FRAMES; map_idx++) {
-    // Get reference frame buffer
-    const RefCntBuffer *const buf =
-        (map_idx != INVALID_IDX) ? cpi->common.ref_frame_map[map_idx] : NULL;
-    if (buf == NULL) continue;
-    const int frame_order = (int)buf->display_order_hint;
+    RefFrameMapPair ref_pair = ref_frame_map_pairs[map_idx];
+    if (ref_pair.disp_order == -1) continue;
+    const int frame_order = ref_pair.disp_order;
+    const int reference_frame_level = ref_pair.pyr_level;
     if (frame_order > cur_frame_disp) continue;
-    const int reference_frame_level = get_true_pyr_level(
-        buf->pyramid_level, frame_order, cpi->gf_group.max_layer_depth);
 
     // Keep track of the oldest reference frame matching the specified
     // refresh level from the subgop cfg
@@ -953,9 +932,9 @@
 }
 
 static int get_refresh_frame_flags_subgop_cfg(
-    const AV1_COMP *const cpi, const EncodeFrameParams *const frame_params,
-    const RefBufferStack *const ref_buffer_stack, int gf_index,
-    int refresh_mask, int free_fb_index) {
+    const AV1_COMP *const cpi, int gf_index, int cur_disp_order,
+    RefFrameMapPair ref_frame_map_pairs[REF_FRAMES], int refresh_mask,
+    int free_fb_index) {
   const SubGOPStepCfg *step_gop_cfg = get_subgop_step(&cpi->gf_group, gf_index);
   assert(step_gop_cfg != NULL);
   const int pyr_level = step_gop_cfg->pyr_level;
@@ -970,63 +949,18 @@
     return refresh_mask;
   }
 
-  // TODO(sarahparker) Fix compatibility with tpl
-  if (!cpi->oxcf.algo_cfg.enable_tpl_model) {
-    const int update_arf =
-        type_code == FRAME_TYPE_OOO_FILTERED && pyr_level == 1;
-    const int refresh_level = step_gop_cfg->refresh;
-    const int refresh_idx =
-        get_refresh_idx(cpi, frame_params, update_arf, refresh_level);
-    return 1 << refresh_idx;
-  }
-
-  switch (type_code) {
-    case FRAME_TYPE_INO_VISIBLE:
-      if (ref_buffer_stack->lst_stack_size >= 2)
-        refresh_mask =
-            1 << ref_buffer_stack
-                     ->lst_stack[ref_buffer_stack->lst_stack_size - 1];
-      else if (ref_buffer_stack->gld_stack_size >= 2)
-        refresh_mask =
-            1 << ref_buffer_stack
-                     ->gld_stack[ref_buffer_stack->gld_stack_size - 1];
-      else
-        assert(0 && "No ref map index found");
-      break;
-    case FRAME_TYPE_OOO_FILTERED:
-      if (pyr_level == 1) {
-        if (ref_buffer_stack->gld_stack_size >= 3)
-          refresh_mask =
-              1 << ref_buffer_stack
-                       ->gld_stack[ref_buffer_stack->gld_stack_size - 1];
-        else if (ref_buffer_stack->lst_stack_size >= 2)
-          refresh_mask =
-              1 << ref_buffer_stack
-                       ->lst_stack[ref_buffer_stack->lst_stack_size - 1];
-        else
-          assert(0 && "No ref map index found");
-
-      } else {
-        refresh_mask =
-            1 << ref_buffer_stack
-                     ->lst_stack[ref_buffer_stack->lst_stack_size - 1];
-      }
-      break;
-    case FRAME_TYPE_OOO_UNFILTERED:
-      refresh_mask = 1 << ref_buffer_stack
-                              ->lst_stack[ref_buffer_stack->lst_stack_size - 1];
-      break;
-    case FRAME_TYPE_INO_REPEAT:
-    case FRAME_TYPE_INO_SHOWEXISTING:
-    default: assert(0); break;
-  }
-  return refresh_mask;
+  const int update_arf = type_code == FRAME_TYPE_OOO_FILTERED && pyr_level == 1;
+  const int refresh_level = step_gop_cfg->refresh;
+  const int refresh_idx = get_refresh_idx(update_arf, refresh_level,
+                                          cur_disp_order, ref_frame_map_pairs);
+  return 1 << refresh_idx;
 }
 
 int av1_get_refresh_frame_flags(const AV1_COMP *const cpi,
                                 const EncodeFrameParams *const frame_params,
                                 FRAME_UPDATE_TYPE frame_update_type,
-                                int gf_index,
+                                int gf_index, int cur_disp_order,
+                                RefFrameMapPair ref_frame_map_pairs[REF_FRAMES],
                                 const RefBufferStack *const ref_buffer_stack) {
   const AV1_COMMON *const cm = &cpi->common;
   const ExtRefreshFrameFlagsInfo *const ext_refresh_frame_flags =
@@ -1098,9 +1032,9 @@
   int free_fb_index = get_free_ref_map_index(ref_buffer_stack);
 
   if (use_subgop_cfg(&cpi->gf_group, gf_index)) {
-    return get_refresh_frame_flags_subgop_cfg(cpi, frame_params,
-                                              ref_buffer_stack, gf_index,
-                                              refresh_mask, free_fb_index);
+    return get_refresh_frame_flags_subgop_cfg(cpi, gf_index, cur_disp_order,
+                                              ref_frame_map_pairs, refresh_mask,
+                                              free_fb_index);
   }
   switch (frame_update_type) {
     case KF_UPDATE:
@@ -1619,9 +1553,15 @@
                                frame_params.order_offset);
     }
 
+    const int cur_frame_disp =
+        cpi->common.current_frame.frame_number + frame_params.order_offset;
+
+    RefFrameMapPair ref_frame_map_pairs[REF_FRAMES];
+    init_ref_map_pair(cpi, ref_frame_map_pairs);
+
     frame_params.refresh_frame_flags = av1_get_refresh_frame_flags(
         cpi, &frame_params, frame_update_type, cpi->gf_group.index,
-        &cpi->ref_buffer_stack);
+        cur_frame_disp, ref_frame_map_pairs, &cpi->ref_buffer_stack);
 
     frame_params.existing_fb_idx_to_show =
         frame_params.show_existing_frame
diff --git a/av1/encoder/encode_strategy.h b/av1/encoder/encode_strategy.h
index e17ac2f..59ee094 100644
--- a/av1/encoder/encode_strategy.h
+++ b/av1/encoder/encode_strategy.h
@@ -69,7 +69,8 @@
 int av1_get_refresh_frame_flags(const AV1_COMP *const cpi,
                                 const EncodeFrameParams *const frame_params,
                                 FRAME_UPDATE_TYPE frame_update_type,
-                                int gf_index,
+                                int gf_index, int cur_frame_disp,
+                                RefFrameMapPair ref_frame_map_pairs[REF_FRAMES],
                                 const RefBufferStack *const ref_buffer_stack);
 
 int av1_get_refresh_ref_frame_map(int refresh_frame_flags);
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 4e6584e..22807ed 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -2820,6 +2820,41 @@
 void av1_set_screen_content_options(const struct AV1_COMP *cpi,
                                     FeatureFlags *features);
 
+static INLINE int get_true_pyr_level(int frame_level, int frame_order,
+                                     int max_layer_depth) {
+  if (frame_order == 0) {
+    // Keyframe case
+    return 1;
+  } else if (frame_level == MAX_ARF_LAYERS) {
+    // Leaves
+    return max_layer_depth;
+  } else if (frame_level == (MAX_ARF_LAYERS + 1)) {
+    // Altrefs
+    return 1;
+  }
+  return frame_level;
+}
+
+typedef struct {
+  int pyr_level;
+  int disp_order;
+} RefFrameMapPair;
+
+static INLINE void init_ref_map_pair(
+    AV1_COMP *cpi, RefFrameMapPair ref_frame_map_pairs[REF_FRAMES]) {
+  memset(ref_frame_map_pairs, -1, sizeof(*ref_frame_map_pairs) * REF_FRAMES);
+  for (int map_idx = 0; map_idx < REF_FRAMES; map_idx++) {
+    // Get reference frame buffer
+    const RefCntBuffer *const buf = cpi->common.ref_frame_map[map_idx];
+    if (buf == NULL) continue;
+    ref_frame_map_pairs[map_idx].disp_order = (int)buf->display_order_hint;
+    const int reference_frame_level = get_true_pyr_level(
+        buf->pyramid_level, ref_frame_map_pairs[map_idx].disp_order,
+        cpi->gf_group.max_layer_depth);
+    ref_frame_map_pairs[map_idx].pyr_level = reference_frame_level;
+  }
+}
+
 // TODO(jingning): Move these functions as primitive members for the new cpi
 // class.
 static INLINE void stack_push(int *stack, int *stack_size, int item) {
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index d5f1f4e..2b5df07 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -927,6 +927,8 @@
   AV1_COMMON *cm = &cpi->common;
   int cur_frame_idx = gf_group->index;
   *pframe_qindex = 0;
+  RefFrameMapPair ref_frame_map_pairs[REF_FRAMES];
+  init_ref_map_pair(cpi, ref_frame_map_pairs);
 
   RefBufferStack ref_buffer_stack = cpi->ref_buffer_stack;
   EncodeFrameParams frame_params = *init_frame_params;
@@ -1006,14 +1008,24 @@
     }
 
     av1_get_ref_frames(cpi, &ref_buffer_stack);
+    const int true_disp =
+        (int)(tpl_frame->frame_display_index) - frame_params.show_frame;
     int refresh_mask = av1_get_refresh_frame_flags(
-        cpi, &frame_params, frame_update_type, gf_index, &ref_buffer_stack);
+        cpi, &frame_params, frame_update_type, gf_index, true_disp,
+        ref_frame_map_pairs, &ref_buffer_stack);
 
     int refresh_frame_map_index = av1_get_refresh_ref_frame_map(refresh_mask);
     av1_update_ref_frame_map(cpi, frame_update_type, frame_params.frame_type,
                              gf_index, frame_params.show_existing_frame,
                              refresh_frame_map_index, &ref_buffer_stack);
 
+    if (refresh_frame_map_index < REF_FRAMES) {
+      ref_frame_map_pairs[refresh_frame_map_index].disp_order =
+          AOMMAX(0, true_disp);
+      ref_frame_map_pairs[refresh_frame_map_index].pyr_level =
+          gf_group->layer_depth[gf_index];
+    }
+
     for (int i = LAST_FRAME; i <= ALTREF_FRAME; ++i)
       tpl_frame->ref_map_index[i - LAST_FRAME] =
           ref_picture_map[cm->remapped_ref_idx[i - LAST_FRAME]];
@@ -1071,8 +1083,11 @@
     // av1_update_ref_frame_map() will execute default behavior even when
     // subgop cfg is enabled. This should be addressed if we ever remove the
     // frame_update_type.
+    const int true_disp =
+        (int)(tpl_frame->frame_display_index) - frame_params.show_frame;
     int refresh_mask = av1_get_refresh_frame_flags(
-        cpi, &frame_params, frame_update_type, -1, &ref_buffer_stack);
+        cpi, &frame_params, frame_update_type, -1, true_disp,
+        ref_frame_map_pairs, &ref_buffer_stack);
     int refresh_frame_map_index = av1_get_refresh_ref_frame_map(refresh_mask);
     av1_update_ref_frame_map(cpi, frame_update_type, frame_params.frame_type,
                              -1, frame_params.show_existing_frame,