Rework tpl model strcture

Rework the data structure and control flow used to build the tpl
model system.

Change-Id: Ia80ea70d62ce38ab6a776b0f2e08c3886611cc0c
diff --git a/av1/encoder/encode_strategy.c b/av1/encoder/encode_strategy.c
index e050f29..6c5fcd1 100644
--- a/av1/encoder/encode_strategy.c
+++ b/av1/encoder/encode_strategy.c
@@ -758,8 +758,7 @@
 }
 #endif  // DUMP_REF_FRAME_IMAGES == 1
 
-static int get_refresh_ref_frame_map(AV1_COMMON *const cm) {
-  int refresh_frame_flags = cm->current_frame.refresh_frame_flags;
+int av1_get_refresh_ref_frame_map(int refresh_frame_flags) {
   int ref_map_index = INVALID_IDX;
 
   for (ref_map_index = 0; ref_map_index < REF_FRAMES; ++ref_map_index)
@@ -816,58 +815,57 @@
 // our reference frame management strategy.
 void av1_update_ref_frame_map(AV1_COMP *cpi,
                               FRAME_UPDATE_TYPE frame_update_type,
+                              int ref_map_index,
                               RefBufferStack *ref_buffer_stack) {
   AV1_COMMON *const cm = &cpi->common;
-
-  int ref_map_index = get_refresh_ref_frame_map(cm);
-
-  if (cm->current_frame.frame_type == KEY_FRAME && cm->show_frame) {
-    stack_reset(ref_buffer_stack->lst_stack, &ref_buffer_stack->lst_stack_size);
-    stack_reset(ref_buffer_stack->gld_stack, &ref_buffer_stack->gld_stack_size);
-    stack_reset(ref_buffer_stack->arf_stack, &ref_buffer_stack->arf_stack_size);
-
-    stack_push(ref_buffer_stack->gld_stack, &ref_buffer_stack->gld_stack_size,
-               ref_map_index);
-  } else {
-    switch (frame_update_type) {
-      case GF_UPDATE:
+  switch (frame_update_type) {
+    case KEY_FRAME:
+      stack_reset(ref_buffer_stack->lst_stack,
+                  &ref_buffer_stack->lst_stack_size);
+      stack_reset(ref_buffer_stack->gld_stack,
+                  &ref_buffer_stack->gld_stack_size);
+      stack_reset(ref_buffer_stack->arf_stack,
+                  &ref_buffer_stack->arf_stack_size);
+      stack_push(ref_buffer_stack->gld_stack, &ref_buffer_stack->gld_stack_size,
+                 ref_map_index);
+      break;
+    case GF_UPDATE:
+      update_arf_stack(cpi, ref_map_index, ref_buffer_stack);
+      stack_push(ref_buffer_stack->gld_stack, &ref_buffer_stack->gld_stack_size,
+                 ref_map_index);
+      break;
+    case LF_UPDATE:
+      update_arf_stack(cpi, ref_map_index, ref_buffer_stack);
+      stack_push(ref_buffer_stack->lst_stack, &ref_buffer_stack->lst_stack_size,
+                 ref_map_index);
+      break;
+    case ARF_UPDATE:
+    case INTNL_ARF_UPDATE:
+      update_arf_stack(cpi, ref_map_index, ref_buffer_stack);
+      stack_push(ref_buffer_stack->arf_stack, &ref_buffer_stack->arf_stack_size,
+                 ref_map_index);
+      break;
+    case OVERLAY_UPDATE:
+      if (cpi->preserve_arf_as_gld || cm->show_existing_frame) {
+        ref_map_index = stack_pop(ref_buffer_stack->arf_stack,
+                                  &ref_buffer_stack->arf_stack_size);
+        stack_push(ref_buffer_stack->gld_stack,
+                   &ref_buffer_stack->gld_stack_size, ref_map_index);
+      } else {
+        stack_pop(ref_buffer_stack->arf_stack,
+                  &ref_buffer_stack->arf_stack_size);
         update_arf_stack(cpi, ref_map_index, ref_buffer_stack);
         stack_push(ref_buffer_stack->gld_stack,
                    &ref_buffer_stack->gld_stack_size, ref_map_index);
-        break;
-      case LF_UPDATE:
-        update_arf_stack(cpi, ref_map_index, ref_buffer_stack);
-        stack_push(ref_buffer_stack->lst_stack,
-                   &ref_buffer_stack->lst_stack_size, ref_map_index);
-        break;
-      case ARF_UPDATE:
-      case INTNL_ARF_UPDATE:
-        update_arf_stack(cpi, ref_map_index, ref_buffer_stack);
-        stack_push(ref_buffer_stack->arf_stack,
-                   &ref_buffer_stack->arf_stack_size, ref_map_index);
-        break;
-      case OVERLAY_UPDATE:
-        if (cpi->preserve_arf_as_gld || cm->show_existing_frame) {
-          ref_map_index = stack_pop(ref_buffer_stack->arf_stack,
-                                    &ref_buffer_stack->arf_stack_size);
-          stack_push(ref_buffer_stack->gld_stack,
-                     &ref_buffer_stack->gld_stack_size, ref_map_index);
-        } else {
-          stack_pop(ref_buffer_stack->arf_stack,
-                    &ref_buffer_stack->arf_stack_size);
-          update_arf_stack(cpi, ref_map_index, ref_buffer_stack);
-          stack_push(ref_buffer_stack->gld_stack,
-                     &ref_buffer_stack->gld_stack_size, ref_map_index);
-        }
-        break;
-      case INTNL_OVERLAY_UPDATE:
-        ref_map_index = stack_pop(ref_buffer_stack->arf_stack,
-                                  &ref_buffer_stack->arf_stack_size);
-        stack_push(ref_buffer_stack->lst_stack,
-                   &ref_buffer_stack->lst_stack_size, ref_map_index);
-        break;
-      default: assert(0 && "unknown type");
-    }
+      }
+      break;
+    case INTNL_OVERLAY_UPDATE:
+      ref_map_index = stack_pop(ref_buffer_stack->arf_stack,
+                                &ref_buffer_stack->arf_stack_size);
+      stack_push(ref_buffer_stack->lst_stack, &ref_buffer_stack->lst_stack_size,
+                 ref_map_index);
+      break;
+    default: assert(0 && "unknown type");
   }
 
   if (!cpi->ext_refresh_frame_flags_pending) return;
@@ -1066,7 +1064,7 @@
   // buffer management strategy currently in use.  This function just decides
   // which buffers should be refreshed.
 
-  int free_fb_index = get_free_ref_map_index(&cpi->ref_buffer_stack);
+  int free_fb_index = get_free_ref_map_index(ref_buffer_stack);
   switch (frame_update_type) {
     case KF_UPDATE:
     case GF_UPDATE:
@@ -1604,7 +1602,7 @@
       frame_params.frame_type == KEY_FRAME && frame_params.show_frame) {
     av1_configure_buffer_updates(cpi, &frame_params, frame_update_type, 0);
     av1_set_frame_size(cpi, cm->width, cm->height);
-    av1_tpl_setup_stats(cpi, &frame_input);
+    av1_tpl_setup_stats(cpi, &frame_params, &frame_input);
   }
 #endif  // ENABLE_KF_TPL
 
@@ -1616,7 +1614,7 @@
       if (cpi->gf_group.index == 1 && cpi->oxcf.enable_tpl_model) {
         av1_configure_buffer_updates(cpi, &frame_params, frame_update_type, 0);
         av1_set_frame_size(cpi, cm->width, cm->height);
-        av1_tpl_setup_stats(cpi, &frame_input);
+        av1_tpl_setup_stats(cpi, &frame_params, &frame_input);
         assert(cpi->num_gf_group_show_frames == 1);
       }
     }
@@ -1640,7 +1638,10 @@
     // First pass doesn't modify reference buffer assignment or produce frame
     // flags
     update_frame_flags(cpi, frame_flags);
-    av1_update_ref_frame_map(cpi, frame_update_type, &cpi->ref_buffer_stack);
+    int ref_map_index =
+        av1_get_refresh_ref_frame_map(cm->current_frame.refresh_frame_flags);
+    av1_update_ref_frame_map(cpi, frame_update_type, ref_map_index,
+                             &cpi->ref_buffer_stack);
   }
 
 #if !CONFIG_REALTIME_ONLY
diff --git a/av1/encoder/encode_strategy.h b/av1/encoder/encode_strategy.h
index 9f54f95..1954a5e 100644
--- a/av1/encoder/encode_strategy.h
+++ b/av1/encoder/encode_strategy.h
@@ -45,8 +45,11 @@
                                 FRAME_UPDATE_TYPE frame_update_type,
                                 const RefBufferStack *const ref_buffer_stack);
 
+int av1_get_refresh_ref_frame_map(int refresh_frame_flags);
+
 void av1_update_ref_frame_map(AV1_COMP *cpi,
                               FRAME_UPDATE_TYPE frame_update_type,
+                              int ref_map_index,
                               RefBufferStack *ref_buffer_stack);
 
 void av1_get_ref_frames(AV1_COMP *const cpi,
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 8ce5a11..70264d0 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -3500,7 +3500,7 @@
   assert(IMPLIES(cpi->gf_group.size > 0,
                  cpi->gf_group.index < cpi->gf_group.size));
   const int tpl_idx = cpi->gf_group.frame_disp_idx[cpi->gf_group.index];
-  TplDepFrame *tpl_frame = &cpi->tpl_stats[tpl_idx];
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_idx];
   TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
   int tpl_stride = tpl_frame->stride;
   int64_t intra_cost = 0;
@@ -3577,7 +3577,7 @@
   assert(IMPLIES(cpi->gf_group.size > 0,
                  cpi->gf_group.index < cpi->gf_group.size));
   const int tpl_idx = cpi->gf_group.frame_disp_idx[cpi->gf_group.index];
-  TplDepFrame *tpl_frame = &cpi->tpl_stats[tpl_idx];
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_idx];
   TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
   int tpl_stride = tpl_frame->stride;
   int64_t intra_cost = 0;
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 81a00b0..8016404 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -2789,16 +2789,18 @@
     int mi_cols = ALIGN_POWER_OF_TWO(cm->mi_cols, MAX_MIB_SIZE_LOG2);
     int mi_rows = ALIGN_POWER_OF_TWO(cm->mi_rows, MAX_MIB_SIZE_LOG2);
 
-    CHECK_MEM_ERROR(cm, cpi->tpl_stats[frame].tpl_stats_ptr,
-                    aom_calloc(mi_rows * mi_cols,
-                               sizeof(*cpi->tpl_stats[frame].tpl_stats_ptr)));
-    cpi->tpl_stats[frame].is_valid = 0;
-    cpi->tpl_stats[frame].width = mi_cols;
-    cpi->tpl_stats[frame].height = mi_rows;
-    cpi->tpl_stats[frame].stride = mi_cols;
-    cpi->tpl_stats[frame].mi_rows = cm->mi_rows;
-    cpi->tpl_stats[frame].mi_cols = cm->mi_cols;
+    CHECK_MEM_ERROR(
+        cm, cpi->tpl_stats_buffer[frame].tpl_stats_ptr,
+        aom_calloc(mi_rows * mi_cols,
+                   sizeof(*cpi->tpl_stats_buffer[frame].tpl_stats_ptr)));
+    cpi->tpl_stats_buffer[frame].is_valid = 0;
+    cpi->tpl_stats_buffer[frame].width = mi_cols;
+    cpi->tpl_stats_buffer[frame].height = mi_rows;
+    cpi->tpl_stats_buffer[frame].stride = mi_cols;
+    cpi->tpl_stats_buffer[frame].mi_rows = cm->mi_rows;
+    cpi->tpl_stats_buffer[frame].mi_cols = cm->mi_cols;
   }
+  cpi->tpl_frame = &cpi->tpl_stats_buffer[REF_FRAMES];
 
 #if CONFIG_COLLECT_PARTITION_STATS == 2
   av1_zero(cpi->partition_stats);
@@ -3130,8 +3132,8 @@
   }
 
   for (int frame = 0; frame < MAX_LENGTH_TPL_FRAME_STATS; ++frame) {
-    aom_free(cpi->tpl_stats[frame].tpl_stats_ptr);
-    cpi->tpl_stats[frame].is_valid = 0;
+    aom_free(cpi->tpl_stats_buffer[frame].tpl_stats_ptr);
+    cpi->tpl_stats_buffer[frame].is_valid = 0;
   }
 
   for (t = cpi->num_workers - 1; t >= 0; --t) {
@@ -3635,7 +3637,7 @@
 
   assert(IMPLIES(gf_group->size > 0, gf_group->index < gf_group->size));
   const int tpl_idx = gf_group->frame_disp_idx[gf_group->index];
-  TplDepFrame *tpl_frame = &cpi->tpl_stats[tpl_idx];
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_idx];
   TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
 
   if (tpl_frame->is_valid) {
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index f8e7313..d73799d 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -157,7 +157,7 @@
   SS_CFG_TOTAL = 2
 } UENUM1BYTE(SS_CFG_OFFSET);
 
-#define MAX_LENGTH_TPL_FRAME_STATS 27
+#define MAX_LENGTH_TPL_FRAME_STATS (27 + 8)
 
 typedef struct TplDepStats {
   int64_t intra_cost;
@@ -175,6 +175,7 @@
 typedef struct TplDepFrame {
   uint8_t is_valid;
   TplDepStats *tpl_stats_ptr;
+  const YV12_BUFFER_CONFIG *gf_picture;
   int stride;
   int width;
   int height;
@@ -753,7 +754,8 @@
   YV12_BUFFER_CONFIG *unscaled_last_source;
   YV12_BUFFER_CONFIG scaled_last_source;
 
-  TplDepFrame tpl_stats[MAX_LENGTH_TPL_FRAME_STATS];
+  TplDepFrame tpl_stats_buffer[MAX_LENGTH_TPL_FRAME_STATS];
+  TplDepFrame *tpl_frame;
 
   // For a still frame, this flag is set to 1 to skip partition search.
   int partition_search_skippable_frame;
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 7821b01..27ce86e 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -443,7 +443,7 @@
   const GF_GROUP *gf_group = &cpi->gf_group;
   if (frame_idx == gf_group->size) return;
   int tpl_idx = gf_group->frame_disp_idx[frame_idx];
-  TplDepFrame *tpl_frame = &cpi->tpl_stats[tpl_idx];
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_idx];
   YV12_BUFFER_CONFIG *this_frame = gf_picture[frame_idx];
   YV12_BUFFER_CONFIG *ref_frame[7] = {
     NULL, NULL, NULL, NULL, NULL, NULL, NULL
@@ -551,7 +551,7 @@
                       tpl_frame->stride, &tpl_stats);
 
       if (frame_idx)
-        tpl_model_update(cpi->tpl_stats, tpl_frame->tpl_stats_ptr, mi_row,
+        tpl_model_update(cpi->tpl_frame, tpl_frame->tpl_stats_ptr, mi_row,
                          mi_col, bsize);
     }
   }
@@ -559,10 +559,10 @@
 
 #define REF_IDX(ref) ((ref)-LAST_FRAME)
 
-static void init_gop_frames_for_tpl(AV1_COMP *cpi,
-                                    YV12_BUFFER_CONFIG **gf_picture,
-                                    GF_GROUP *gf_group, int *tpl_group_frames,
-                                    const EncodeFrameInput *const frame_input) {
+static void init_gop_frames_for_tpl(
+    AV1_COMP *cpi, const EncodeFrameParams *const init_frame_params,
+    YV12_BUFFER_CONFIG **gf_picture, GF_GROUP *gf_group, int *tpl_group_frames,
+    const EncodeFrameInput *const frame_input) {
   AV1_COMMON *cm = &cpi->common;
   const SequenceHeader *const seq_params = &cm->seq_params;
   int frame_idx = 0;
@@ -571,6 +571,9 @@
   int pframe_qindex = 0;
   int cur_frame_idx = gf_group->index;
 
+  RefBufferStack ref_buffer_stack = cpi->ref_buffer_stack;
+  EncodeFrameParams frame_params = *init_frame_params;
+
   for (int i = 0; i < FRAME_BUFFERS && frame_idx < INTER_REFS_PER_FRAME + 1;
        ++i) {
     if (frame_bufs[i].ref_count == 0) {
@@ -586,6 +589,29 @@
     }
   }
 
+  for (int i = LAST_FRAME; i < REF_FRAMES; ++i)
+    cpi->tpl_frame[-i].gf_picture = get_ref_frame_yv12_buf(cm, i);
+
+  for (int gf_index = gf_group->index; gf_index < gf_group->size; ++gf_index) {
+    FRAME_UPDATE_TYPE frame_update_type = gf_group->update_type[gf_index];
+    frame_params.show_frame = frame_update_type != ARF_UPDATE &&
+                              frame_update_type != INTNL_ARF_UPDATE;
+    frame_params.show_existing_frame =
+        frame_update_type == INTNL_OVERLAY_UPDATE;
+    frame_params.frame_type =
+        frame_update_type == KF_UPDATE ? KEY_FRAME : INTER_FRAME;
+
+    av1_get_ref_frames(cpi, frame_update_type, &ref_buffer_stack);
+    int refresh_mask = av1_get_refresh_frame_flags(
+        cpi, &frame_params, frame_update_type, &ref_buffer_stack);
+    int ref_map_index = av1_get_refresh_ref_frame_map(refresh_mask);
+    av1_update_ref_frame_map(cpi, frame_update_type, ref_map_index,
+                             &ref_buffer_stack);
+  }
+
+  av1_get_ref_frames(cpi, gf_group->update_type[gf_group->index],
+                     &cpi->ref_buffer_stack);
+
   *tpl_group_frames = 0;
 
   if (cur_frame_idx > 0) {
@@ -699,7 +725,7 @@
 
 static void init_tpl_stats(AV1_COMP *cpi) {
   for (int frame_idx = 0; frame_idx < MAX_LENGTH_TPL_FRAME_STATS; ++frame_idx) {
-    TplDepFrame *tpl_frame = &cpi->tpl_stats[frame_idx];
+    TplDepFrame *tpl_frame = &cpi->tpl_stats_buffer[frame_idx];
     memset(tpl_frame->tpl_stats_ptr, 0,
            tpl_frame->height * tpl_frame->width *
                sizeof(*tpl_frame->tpl_stats_ptr));
@@ -708,12 +734,15 @@
 }
 
 void av1_tpl_setup_stats(AV1_COMP *cpi,
+                         const EncodeFrameParams *const frame_params,
                          const EncodeFrameInput *const frame_input) {
-  YV12_BUFFER_CONFIG *gf_picture[MAX_LENGTH_TPL_FRAME_STATS];
+  YV12_BUFFER_CONFIG
+  *gf_picture_buffer[MAX_LENGTH_TPL_FRAME_STATS + REF_FRAMES];
+  YV12_BUFFER_CONFIG **gf_picture = &gf_picture_buffer[REF_FRAMES];
   GF_GROUP *gf_group = &cpi->gf_group;
 
-  init_gop_frames_for_tpl(cpi, gf_picture, gf_group, &cpi->tpl_gf_group_frames,
-                          frame_input);
+  init_gop_frames_for_tpl(cpi, frame_params, gf_picture, gf_group,
+                          &cpi->tpl_gf_group_frames, frame_input);
 
   init_tpl_stats(cpi);
 
@@ -932,7 +961,7 @@
   const GF_GROUP *gf_group = &cpi->gf_group;
   assert(IMPLIES(gf_group->size > 0, gf_group->index < gf_group->size));
   const int tpl_cur_idx = gf_group->frame_disp_idx[gf_group->index];
-  TplDepFrame *tpl_frame = &cpi->tpl_stats[tpl_cur_idx];
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_cur_idx];
   memset(
       tpl_frame->tpl_stats_ptr, 0,
       tpl_frame->height * tpl_frame->width * sizeof(*tpl_frame->tpl_stats_ptr));
diff --git a/av1/encoder/tpl_model.h b/av1/encoder/tpl_model.h
index 4732d1c..d089b3f 100644
--- a/av1/encoder/tpl_model.h
+++ b/av1/encoder/tpl_model.h
@@ -17,6 +17,7 @@
 #endif
 
 void av1_tpl_setup_stats(AV1_COMP *cpi,
+                         const EncodeFrameParams *const frame_params,
                          const EncodeFrameInput *const frame_input);
 
 void av1_tpl_setup_forward_stats(AV1_COMP *cpi);