Rework the tpl model construction

Support forward building and backward synthesis to build the
temporal dependency model.

Change-Id: Id04d838647eaf366c3a2b2a25028cdcdadcb7d87
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 611c71e..df29b54 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -166,6 +166,9 @@
   int64_t inter_cost;
   int64_t mc_flow;
   int64_t mc_dep_cost;
+  int_mv mv;
+  int ref_frame_index;
+  double quant_ratio;
 #if !USE_TPL_CLASSIC_MODEL
   int64_t mc_count;
   int64_t mc_saved;
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 7cbc824..a7fadf7 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -139,8 +139,7 @@
     int frame_idx, int16_t *src_diff, tran_low_t *coeff, tran_low_t *qcoeff,
     tran_low_t *dqcoeff, int use_satd, int mi_row, int mi_col, BLOCK_SIZE bsize,
     TX_SIZE tx_size, const YV12_BUFFER_CONFIG *ref_frame[], uint8_t *predictor,
-    int64_t *recon_error, int64_t *sse, TplDepStats *tpl_stats,
-    int *ref_frame_index, int_mv *mv) {
+    int64_t *recon_error, int64_t *sse, TplDepStats *tpl_stats) {
   AV1_COMMON *cm = &cpi->common;
   const GF_GROUP *gf_group = &cpi->gf_group;
 
@@ -323,8 +322,9 @@
   tpl_stats->intra_cost = best_intra_cost << TPL_DEP_COST_SCALE_LOG2;
 
   if (frame_idx && best_rf_idx != -1) {
-    *ref_frame_index = cpi->tpl_frame[frame_idx].ref_map_index[best_rf_idx];
-    mv->as_int = best_mv.as_int;
+    tpl_stats->mv.as_int = best_mv.as_int;
+    tpl_stats->ref_frame_index =
+        cpi->tpl_frame[frame_idx].ref_map_index[best_rf_idx];
   }
 }
 
@@ -383,14 +383,12 @@
 static AOM_INLINE void tpl_model_update_b(AV1_COMP *cpi, TplDepFrame *tpl_frame,
                                           TplDepStats *tpl_stats_ptr,
                                           int mi_row, int mi_col,
-                                          double quant_ratio,
-                                          const BLOCK_SIZE bsize,
-                                          int ref_frame_index, int_mv mv) {
-  TplDepFrame *ref_tpl_frame = &tpl_frame[ref_frame_index];
+                                          const BLOCK_SIZE bsize) {
+  TplDepFrame *ref_tpl_frame = &tpl_frame[tpl_stats_ptr->ref_frame_index];
   TplDepStats *ref_stats_ptr = ref_tpl_frame->tpl_stats_ptr;
 
-  const int ref_pos_row = mi_row * MI_SIZE + (mv.as_mv.row >> 3);
-  const int ref_pos_col = mi_col * MI_SIZE + (mv.as_mv.col >> 3);
+  const int ref_pos_row = mi_row * MI_SIZE + (tpl_stats_ptr->mv.as_mv.row >> 3);
+  const int ref_pos_col = mi_col * MI_SIZE + (tpl_stats_ptr->mv.as_mv.col >> 3);
 
   const int bw = 4 << mi_size_wide_log2[bsize];
   const int bh = 4 << mi_size_high_log2[bsize];
@@ -416,8 +414,11 @@
 
       const double iiratio_nl = iiratio_nonlinear(
           (double)tpl_stats_ptr->inter_cost / tpl_stats_ptr->intra_cost);
-      int64_t mc_flow = (int64_t)(quant_ratio * tpl_stats_ptr->mc_dep_cost *
-                                  (1.0 - iiratio_nl));
+      tpl_stats_ptr->mc_dep_cost =
+          tpl_stats_ptr->intra_cost + tpl_stats_ptr->mc_flow;
+      int64_t mc_flow =
+          (int64_t)(tpl_stats_ptr->quant_ratio * tpl_stats_ptr->mc_dep_cost *
+                    (1.0 - iiratio_nl));
 #if !USE_TPL_CLASSIC_MODEL
       int64_t mc_saved = tpl_stats_ptr->intra_cost - tpl_stats_ptr->inter_cost;
 #endif  // #if !USE_TPL_CLASSIC_MODEL
@@ -440,9 +441,7 @@
 
 static AOM_INLINE void tpl_model_update(AV1_COMP *cpi, TplDepFrame *tpl_frame,
                                         TplDepStats *tpl_stats_ptr, int mi_row,
-                                        int mi_col, double quant_ratio,
-                                        const BLOCK_SIZE bsize,
-                                        int ref_frame_index, int_mv mv) {
+                                        int mi_col, const BLOCK_SIZE bsize) {
   const int mi_height = mi_size_high[bsize];
   const int mi_width = mi_size_wide[bsize];
   const int step = 1 << cpi->tpl_stats_block_mis_log2;
@@ -454,7 +453,7 @@
       TplDepStats *tpl_ptr = &tpl_stats_ptr[av1_tpl_ptr_pos(
           cpi, mi_row + idy, mi_col + idx, tpl_frame->stride)];
       tpl_model_update_b(cpi, tpl_frame, tpl_ptr, mi_row + idy, mi_col + idx,
-                         quant_ratio, tpl_block_size, ref_frame_index, mv);
+                         tpl_block_size);
     }
   }
 }
@@ -469,19 +468,18 @@
 
   int64_t intra_cost = src_stats->intra_cost / (mi_height * mi_width);
   int64_t inter_cost = src_stats->inter_cost / (mi_height * mi_width);
-
-  TplDepStats *tpl_ptr;
-
   intra_cost = AOMMAX(1, intra_cost);
   inter_cost = AOMMAX(1, inter_cost);
 
   for (int idy = 0; idy < mi_height; idy += step) {
-    tpl_ptr =
+    TplDepStats *tpl_ptr =
         &tpl_stats_ptr[av1_tpl_ptr_pos(cpi, mi_row + idy, mi_col, stride)];
     for (int idx = 0; idx < mi_width; idx += step) {
       tpl_ptr->intra_cost = intra_cost;
       tpl_ptr->inter_cost = inter_cost;
-      tpl_ptr->mc_dep_cost = tpl_ptr->intra_cost + tpl_ptr->mc_flow;
+      tpl_ptr->quant_ratio = src_stats->quant_ratio;
+      tpl_ptr->mv.as_int = src_stats->mv.as_int;
+      tpl_ptr->ref_frame_index = src_stats->ref_frame_index;
       ++tpl_ptr;
     }
   }
@@ -608,10 +606,6 @@
     xd->mb_to_bottom_edge = ((cm->mi_rows - mi_height - mi_row) * MI_SIZE) * 8;
     for (mi_col = 0; mi_col < cm->mi_cols; mi_col += mi_width) {
       TplDepStats tpl_stats;
-      int ref_frame_index = -1;
-      int_mv mv;
-
-      mv.as_int = 0;
 
       // Motion estimation column boundary
       x->mv_limits.col_min =
@@ -622,17 +616,13 @@
       xd->mb_to_right_edge = ((cm->mi_cols - mi_width - mi_col) * MI_SIZE) * 8;
       mode_estimation(cpi, x, xd, &sf, frame_idx, src_diff, coeff, qcoeff,
                       dqcoeff, 1, mi_row, mi_col, bsize, tx_size, ref_frame,
-                      predictor, &recon_error, &sse, &tpl_stats,
-                      &ref_frame_index, &mv);
+                      predictor, &recon_error, &sse, &tpl_stats);
 
       // Motion flow dependency dispenser.
+      double quant_ratio = (double)recon_error / sse;
+      tpl_stats.quant_ratio = quant_ratio;
       tpl_model_store(cpi, tpl_frame->tpl_stats_ptr, mi_row, mi_col, bsize,
                       tpl_frame->stride, &tpl_stats);
-      double quant_ratio = (double)recon_error / sse;
-      if (frame_idx) {
-        tpl_model_update(cpi, cpi->tpl_frame, tpl_frame->tpl_stats_ptr, mi_row,
-                         mi_col, quant_ratio, bsize, ref_frame_index, mv);
-      }
     }
   }
 
@@ -640,6 +630,28 @@
   xd->mi = backup_mi_grid;
 }
 
+static void mc_flow_synthesizer(AV1_COMP *cpi, int frame_idx) {
+  AV1_COMMON *cm = &cpi->common;
+
+  const GF_GROUP *gf_group = &cpi->gf_group;
+  if (frame_idx == gf_group->size) return;  // no work at index == group size
+
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[frame_idx];
+
+  const BLOCK_SIZE bsize = convert_length_to_bsize(MC_FLOW_BSIZE_1D);
+  const int mi_height = mi_size_high[bsize];
+  const int mi_width = mi_size_wide[bsize];
+  // Call tpl_model_update() on every bsize-stepped block of this frame.
+  for (int mi_row = 0; mi_row < cm->mi_rows; mi_row += mi_height) {
+    for (int mi_col = 0; mi_col < cm->mi_cols; mi_col += mi_width) {
+      if (frame_idx) {  // loop-invariant: frame 0 is never updated
+        tpl_model_update(cpi, cpi->tpl_frame, tpl_frame->tpl_stats_ptr, mi_row,
+                         mi_col, bsize);
+      }  // NOTE(review): confirm each block's ref_frame_index was set
+    }
+  }
+}
+
 static AOM_INLINE void init_gop_frames_for_tpl(
     AV1_COMP *cpi, const EncodeFrameParams *const init_frame_params,
     GF_GROUP *gf_group, int *tpl_group_frames,
@@ -832,13 +844,22 @@
 
   if (cpi->oxcf.enable_tpl_model == 1) {
     // Backward propagation from tpl_group_frames to 1.
+    for (int frame_idx = gf_group->index; frame_idx < cpi->tpl_gf_group_frames;
+         ++frame_idx) {
+      if (gf_group->update_type[frame_idx] == INTNL_OVERLAY_UPDATE ||
+          gf_group->update_type[frame_idx] == OVERLAY_UPDATE)
+        continue;
+
+      mc_flow_dispenser(cpi, frame_idx);
+    }
+
     for (int frame_idx = cpi->tpl_gf_group_frames - 1;
          frame_idx >= gf_group->index; --frame_idx) {
       if (gf_group->update_type[frame_idx] == INTNL_OVERLAY_UPDATE ||
           gf_group->update_type[frame_idx] == OVERLAY_UPDATE)
         continue;
 
-      mc_flow_dispenser(cpi, frame_idx);
+      mc_flow_synthesizer(cpi, frame_idx);
     }
   }