Build coding block stats for two-pass TPL runs

Convert the data flow in the two-pass TPL to support the design
of the superblock level QP offsets.

Change-Id: I926fa2fe0d9842c189b6dd9ad0beabcb7333caed
diff --git a/av1/qmode_rc/ratectrl_qmode.cc b/av1/qmode_rc/ratectrl_qmode.cc
index 28cd7a6..d78069a 100644
--- a/av1/qmode_rc/ratectrl_qmode.cc
+++ b/av1/qmode_rc/ratectrl_qmode.cc
@@ -1147,6 +1147,87 @@
          ModifyDivisor(unit_dep_stats.intra_cost);
 }
 
+void TplFrameDepStatsBackTrace(int coding_idx,
+                               const RefFrameTable &ref_frame_table,
+                               TplGopDepStats *tpl_gop_dep_stats) {
+  assert(!tpl_gop_dep_stats->frame_dep_stats_list.empty());
+  TplFrameDepStats *frame_dep_stats =
+      &tpl_gop_dep_stats->frame_dep_stats_list[coding_idx];
+
+  if (frame_dep_stats->unit_stats.empty()) return;
+  if (frame_dep_stats->alt_unit_stats.empty()) return;
+
+  const int unit_size = frame_dep_stats->unit_size;
+  const int frame_unit_rows =
+      static_cast<int>(frame_dep_stats->unit_stats.size());
+  const int frame_unit_cols =
+      static_cast<int>(frame_dep_stats->unit_stats[0].size());
+  for (int unit_row = 0; unit_row < frame_unit_rows; ++unit_row) {
+    for (int unit_col = 0; unit_col < frame_unit_cols; ++unit_col) {
+      TplUnitDepStats &unit_dep_stats =
+          frame_dep_stats->unit_stats[unit_row][unit_col];
+      TplUnitDepStats &alt_unit_dep_stats =
+          frame_dep_stats->alt_unit_stats[unit_row][unit_col];
+
+      int ref_coding_idx_list[kBlockRefCount] = { -1, -1 };
+      int ref_frame_count = GetRefCodingIdxList(
+          alt_unit_dep_stats, ref_frame_table, ref_coding_idx_list);
+      if (ref_frame_count == 0) continue;
+      MotionVector base_mv[2] = { alt_unit_dep_stats.mv[0],
+                                  alt_unit_dep_stats.mv[1] };
+      for (int i = 0; i < kBlockRefCount; ++i) {
+        if (ref_coding_idx_list[i] == -1) continue;
+        assert(
+            ref_coding_idx_list[i] <
+            static_cast<int>(tpl_gop_dep_stats->frame_dep_stats_list.size()));
+        TplFrameDepStats &ref_frame_dep_stats =
+            tpl_gop_dep_stats->frame_dep_stats_list[ref_coding_idx_list[i]];
+        assert(!ref_frame_dep_stats.alt_unit_stats.empty());
+        const auto &mv = base_mv[i];
+        const int mv_row = GetFullpelValue(mv.row, mv.subpel_bits);
+        const int mv_col = GetFullpelValue(mv.col, mv.subpel_bits);
+        const int ref_pixel_r = unit_row * unit_size + mv_row;
+        const int ref_pixel_c = unit_col * unit_size + mv_col;
+        const int ref_unit_row_low =
+            (unit_row * unit_size + mv_row) / unit_size;
+        const int ref_unit_col_low =
+            (unit_col * unit_size + mv_col) / unit_size;
+
+        for (int j = 0; j < 2; ++j) {
+          for (int k = 0; k < 2; ++k) {
+            const int ref_unit_row = ref_unit_row_low + j;
+            const int ref_unit_col = ref_unit_col_low + k;
+            if (ref_unit_row >= 0 && ref_unit_row < frame_unit_rows &&
+                ref_unit_col >= 0 && ref_unit_col < frame_unit_cols) {
+              const int overlap_area = GetBlockOverlapArea(
+                  ref_pixel_r, ref_pixel_c, ref_unit_row * unit_size,
+                  ref_unit_col * unit_size, unit_size);
+              const double overlap_ratio =
+                  overlap_area * 1.0 / (unit_size * unit_size);
+              const double propagation_ratio =
+                  1.0 / ref_frame_count * overlap_ratio;
+              TplUnitDepStats &alt_ref_unit_stats =
+                  ref_frame_dep_stats
+                      .alt_unit_stats[ref_unit_row][ref_unit_col];
+              alt_ref_unit_stats.propagation_cost +=
+                  (alt_unit_dep_stats.inter_cost +
+                   alt_unit_dep_stats.propagation_cost) *
+                  propagation_ratio;
+
+              TplUnitDepStats &ref_unit_stats =
+                  ref_frame_dep_stats.unit_stats[ref_unit_row][ref_unit_col];
+              ref_unit_stats.propagation_cost +=
+                  (unit_dep_stats.inter_cost +
+                   unit_dep_stats.propagation_cost) *
+                  propagation_ratio;
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
 void TplFrameDepStatsPropagate(int coding_idx,
                                const RefFrameTable &ref_frame_table,
                                TplGopDepStats *tpl_gop_dep_stats) {
@@ -1306,7 +1387,17 @@
     auto &ref_frame_table = ref_frame_table_list[coding_idx];
     // TODO(angiebird): Handle/test the case where reference frame
     // is in the previous GOP
-    TplFrameDepStatsPropagate(coding_idx, ref_frame_table, &tpl_gop_dep_stats);
+
+    if (tpl_frame_stats_list_with_lookahead[coding_idx]
+            ->alternate_block_stats_list.empty()) {
+      // One pass TPL run
+      TplFrameDepStatsPropagate(coding_idx, ref_frame_table,
+                                &tpl_gop_dep_stats);
+    } else {
+      // Two pass TPL runs
+      TplFrameDepStatsBackTrace(coding_idx, ref_frame_table,
+                                &tpl_gop_dep_stats);
+    }
   }
   return tpl_gop_dep_stats;
 }
@@ -1314,7 +1405,8 @@
 static std::vector<uint8_t> SetupDeltaQ(const TplFrameDepStats &frame_dep_stats,
                                         int frame_width, int frame_height,
                                         int base_qindex,
-                                        double frame_importance) {
+                                        double frame_importance,
+                                        bool use_twopass_data) {
   // TODO(jianj) : Add support to various superblock sizes.
   const int sb_size = 64;
   const int delta_q_res = 4;
@@ -1326,6 +1418,38 @@
   const int unit_cols =
       (frame_width + frame_dep_stats.unit_size - 1) / frame_dep_stats.unit_size;
   std::vector<uint8_t> superblock_q_indices;
+
+  if (use_twopass_data) {
+    // Cumulate frame level stats
+    double cum_inter_cost = 0;
+    double cum_rdcost_diff = 0;
+    for (int sb_row = 0; sb_row < sb_rows; ++sb_row) {
+      for (int sb_col = 0; sb_col < sb_cols; ++sb_col) {
+        const int unit_row_start = sb_row * num_unit_per_sb;
+        const int unit_row_end =
+            std::min((sb_row + 1) * num_unit_per_sb, unit_rows);
+        const int unit_col_start = sb_col * num_unit_per_sb;
+        const int unit_col_end =
+            std::min((sb_col + 1) * num_unit_per_sb, unit_cols);
+        // A simplified version of av1_get_q_for_deltaq_objective()
+        for (int unit_row = unit_row_start; unit_row < unit_row_end;
+             ++unit_row) {
+          for (int unit_col = unit_col_start; unit_col < unit_col_end;
+               ++unit_col) {
+            const TplUnitDepStats &unit_dep_stats =
+                frame_dep_stats.unit_stats[unit_row][unit_col];
+            const TplUnitDepStats &alt_unit_dep_stats =
+                frame_dep_stats.alt_unit_stats[unit_row][unit_col];
+            cum_inter_cost += unit_dep_stats.inter_cost;
+            cum_rdcost_diff += (unit_dep_stats.propagation_cost -
+                                alt_unit_dep_stats.propagation_cost);
+          }
+        }
+      }
+    }
+    frame_importance = (cum_rdcost_diff + cum_inter_cost) / cum_inter_cost;
+  }
+
   // Calculate delta_q offset for each superblock.
   for (int sb_row = 0; sb_row < sb_rows; ++sb_row) {
     for (int sb_col = 0; sb_col < sb_cols; ++sb_col) {
@@ -1341,10 +1465,16 @@
       for (int unit_row = unit_row_start; unit_row < unit_row_end; ++unit_row) {
         for (int unit_col = unit_col_start; unit_col < unit_col_end;
              ++unit_col) {
-          const TplUnitDepStats &unit_dep_stat =
+          const TplUnitDepStats &unit_dep_stats =
               frame_dep_stats.unit_stats[unit_row][unit_col];
-          intra_cost += unit_dep_stat.intra_cost;
-          mc_dep_cost += unit_dep_stat.propagation_cost;
+          intra_cost += unit_dep_stats.inter_cost;
+          mc_dep_cost += unit_dep_stats.propagation_cost;
+
+          if (use_twopass_data) {
+            const TplUnitDepStats &alt_unit_dep_stats =
+                frame_dep_stats.alt_unit_stats[unit_row][unit_col];
+            mc_dep_cost -= alt_unit_dep_stats.propagation_cost;
+          }
         }
       }
 
@@ -1882,9 +2012,11 @@
       active_best_quality = param.q_index;
 
       if (rc_param_.max_distinct_q_indices_per_frame > 1) {
-        std::vector<uint8_t> superblock_q_indices = SetupDeltaQ(
+        std::vector<uint8_t> superblock_q_indices;
+        superblock_q_indices = SetupDeltaQ(
             frame_dep_stats, rc_param_.frame_width, rc_param_.frame_height,
-            param.q_index, frame_importance);
+            param.q_index, frame_importance,
+            rc_param_.tpl_pass_count == TplPassCount::kTwoTplPasses);
         std::unordered_map<int, int> qindex_centroids = internal::KMeans(
             superblock_q_indices, rc_param_.max_distinct_q_indices_per_frame);
         for (size_t i = 0; i < superblock_q_indices.size(); ++i) {
diff --git a/av1/qmode_rc/ratectrl_qmode.h b/av1/qmode_rc/ratectrl_qmode.h
index 4f2180c..29623d6 100644
--- a/av1/qmode_rc/ratectrl_qmode.h
+++ b/av1/qmode_rc/ratectrl_qmode.h
@@ -92,6 +92,10 @@
 
 double TplFrameDepStatsAccumulate(const TplFrameDepStats &frame_dep_stats);
 
+void TplFrameDepStatsBackTrace(int coding_idx,
+                               const RefFrameTable &ref_frame_table,
+                               TplGopDepStats *tpl_gop_dep_stats);
+
 void TplFrameDepStatsPropagate(int coding_idx,
                                const RefFrameTable &ref_frame_table,
                                TplGopDepStats *tpl_gop_dep_stats);