Use two TPL passes to estimate the propagation factor

Change-Id: I808033427cae547987c2c3205aefd92035710b75
diff --git a/av1/qmode_rc/ratectrl_qmode.cc b/av1/qmode_rc/ratectrl_qmode.cc
index b11a609..8a02d2e 100644
--- a/av1/qmode_rc/ratectrl_qmode.cc
+++ b/av1/qmode_rc/ratectrl_qmode.cc
@@ -899,7 +899,8 @@
 }
 
 TplFrameDepStats CreateTplFrameDepStats(int frame_height, int frame_width,
-                                        int min_block_size) {
+                                        int min_block_size,
+                                        bool has_alt_stats) {
   const int unit_rows = (frame_height + min_block_size - 1) / min_block_size;
   const int unit_cols = (frame_width + min_block_size - 1) / min_block_size;
   TplFrameDepStats frame_dep_stats;
@@ -908,6 +909,13 @@
   for (auto &row : frame_dep_stats.unit_stats) {
     row.resize(unit_cols);
   }
+
+  if (has_alt_stats) {
+    frame_dep_stats.alt_unit_stats.resize(unit_rows);
+    for (auto &row : frame_dep_stats.alt_unit_stats) {
+      row.resize(unit_cols);
+    }
+  }
   return frame_dep_stats;
 }
 
@@ -919,7 +927,10 @@
   // In rare case, inter_cost may be greater than intra_cost.
   // If so, we need to modify inter_cost such that inter_cost <= intra_cost
   // because it is required by GetPropagationFraction()
-  dep_stats.inter_cost = std::min(dep_stats.intra_cost, dep_stats.inter_cost);
+  if (block_stats.ref_frame_index[0] >= 0)
+    dep_stats.inter_cost = std::min(dep_stats.intra_cost, dep_stats.inter_cost);
+  else
+    dep_stats.inter_cost = dep_stats.intra_cost;
   dep_stats.mv = block_stats.mv;
   dep_stats.ref_frame_index = block_stats.ref_frame_index;
   return dep_stats;
@@ -985,19 +996,15 @@
 }
 }  // namespace
 
-StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
-    const TplFrameStats &frame_stats) {
-  if (frame_stats.block_stats_list.empty()) {
-    return TplFrameDepStats();
-  }
+Status FillTplUnitDepStats(TplFrameDepStats &frame_dep_stats,
+                           const TplFrameStats &frame_stats,
+                           const std::vector<TplBlockStats> &block_stats_list) {
   const int min_block_size = frame_stats.min_block_size;
   const int unit_rows =
       (frame_stats.frame_height + min_block_size - 1) / min_block_size;
   const int unit_cols =
       (frame_stats.frame_width + min_block_size - 1) / min_block_size;
-  TplFrameDepStats frame_dep_stats = CreateTplFrameDepStats(
-      frame_stats.frame_height, frame_stats.frame_width, min_block_size);
-  for (const TplBlockStats &block_stats : frame_stats.block_stats_list) {
+  for (const TplBlockStats &block_stats : block_stats_list) {
     Status status =
         ValidateBlockStats(frame_stats, block_stats, min_block_size);
     if (!status.ok()) {
@@ -1022,8 +1029,32 @@
       }
     }
   }
+  return { AOM_CODEC_OK, "" };
+}
 
-  frame_dep_stats.rdcost = TplFrameDepStatsAccumulateInterCost(frame_dep_stats);
+StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
+    const TplFrameStats &frame_stats) {
+  if (frame_stats.block_stats_list.empty()) {
+    return TplFrameDepStats();
+  }
+  const int min_block_size = frame_stats.min_block_size;
+  TplFrameDepStats frame_dep_stats = CreateTplFrameDepStats(
+      frame_stats.frame_height, frame_stats.frame_width, min_block_size,
+      !frame_stats.alternate_block_stats_list.empty());
+
+  Status status = FillTplUnitDepStats(frame_dep_stats, frame_stats,
+                                      frame_stats.block_stats_list);
+  if (!status.ok()) return status;
+
+  if (!frame_stats.alternate_block_stats_list.empty()) {
+    status = FillTplUnitDepStats(frame_dep_stats, frame_stats,
+                                 frame_stats.alternate_block_stats_list);
+    if (!status.ok()) return status;
+    frame_dep_stats.rdcost =
+        TplFrameDepStatsAccumulateInterCost(frame_dep_stats.unit_stats);
+    frame_dep_stats.alt_rdcost =
+        TplFrameDepStatsAccumulateInterCost(frame_dep_stats.alt_unit_stats);
+  }
 
   return frame_dep_stats;
 }
@@ -1070,12 +1101,12 @@
 }
 
 double TplFrameDepStatsAccumulateInterCost(
-    const TplFrameDepStats &frame_dep_stats) {
+    const std::vector<std::vector<TplUnitDepStats>> &unit_stats) {
   auto getInterCost = [](double sum, const TplUnitDepStats &unit) {
     return sum + unit.inter_cost;
   };
   double sum = 0;
-  for (const auto &row : frame_dep_stats.unit_stats) {
+  for (const auto &row : unit_stats) {
     sum = std::accumulate(row.begin(), row.end(), sum, getInterCost);
   }
   return std::max(sum, 1.0);
@@ -1427,10 +1458,7 @@
       if (gop_frame.update_type == GopFrameType::kRegularGolden ||
           gop_frame.update_type == GopFrameType::kRegularKey ||
           gop_frame.update_type == GopFrameType::kRegularArf) {
-        double qstep_ratio = 1 / 3.0;
-        param.q_index = av1_get_q_index_from_qstep_ratio(
-            rc_param_.base_q_index, qstep_ratio, AOM_BITS_8);
-        if (rc_param_.base_q_index) param.q_index = AOMMAX(param.q_index, 1);
+        param.q_index = 5;
       }
     }
     gop_encode_info.param_list.push_back(param);
@@ -1464,6 +1492,36 @@
       static_cast<int>(tpl_gop_stats.frame_stats_list.size());
   const int active_worst_quality = rc_param_.base_q_index;
   int active_best_quality = rc_param_.base_q_index;
+
+  double base_rdcost = 1.0;  // baseline total rdcost
+  double hqr_rdcost = 0;     // high quality reference total rdcost
+  double arf_rdcost_high = 1.0;
+  double arf_rdcost_low = 0;
+
+  bool kf_arf_seen = false;
+
+  for (int i = 0; i < frame_count; ++i) {
+    FrameEncodeParameters param;
+    const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
+    const TplFrameDepStats &frame_dep_stats =
+        gop_dep_stats->frame_dep_stats_list[i];
+    if (gop_frame.update_type == GopFrameType::kRegularGolden ||
+        gop_frame.update_type == GopFrameType::kRegularKey ||
+        gop_frame.update_type == GopFrameType::kRegularArf) {
+      if (!kf_arf_seen) {
+        arf_rdcost_high += frame_dep_stats.rdcost;
+        arf_rdcost_low += frame_dep_stats.alt_rdcost;
+      }
+      kf_arf_seen = 1;
+    } else {
+      base_rdcost += frame_dep_stats.rdcost;
+      hqr_rdcost += frame_dep_stats.alt_rdcost;
+    }
+  }
+
+  double tp_frame_importance =
+      1.0 + fabs((base_rdcost - hqr_rdcost) / arf_rdcost_high);
+
   for (int i = 0; i < frame_count; i++) {
     FrameEncodeParameters param;
     const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
@@ -1481,8 +1539,15 @@
           TplFrameDepStatsAccumulateIntraCost(frame_dep_stats);
       const double cost_with_propagation =
           TplFrameDepStatsAccumulate(frame_dep_stats);
-      const double frame_importance =
+      double frame_importance =
           cost_with_propagation / cost_without_propagation;
+
+      // TODO(jingning): Temporarily switch between the single-pass and
+      // two-pass TPL estimates based on the configured TPL pass count.
+      // This code needs further changes to support SB-level calculation.
+      if (rc_param_.tpl_pass_count == TplPassCount::kTwoTplPasses)
+        frame_importance = tp_frame_importance;
+
       // Imitate the behavior of av1_tpl_get_qstep_ratio()
       const double qstep_ratio = sqrt(1 / frame_importance);
       param.q_index = av1_get_q_index_from_qstep_ratio(rc_param_.base_q_index,
diff --git a/av1/qmode_rc/ratectrl_qmode.h b/av1/qmode_rc/ratectrl_qmode.h
index f60000e..70e112a 100644
--- a/av1/qmode_rc/ratectrl_qmode.h
+++ b/av1/qmode_rc/ratectrl_qmode.h
@@ -35,9 +35,11 @@
 };
 
 struct TplFrameDepStats {
-  int unit_size;  // equivalent to min_block_size
-  double rdcost;  // overall rate-distortion cost
+  int unit_size;      // equivalent to min_block_size
+  double rdcost;      // overall rate-distortion cost
+  double alt_rdcost;  // rate-distortion cost in the second tpl pass
   std::vector<std::vector<TplUnitDepStats>> unit_stats;
+  std::vector<std::vector<TplUnitDepStats>> alt_unit_stats;
 };
 
 struct TplGopDepStats {
@@ -66,11 +68,15 @@
 // and blocks along the bottom or right edge of the frame may extend beyond the
 // edges of the frame.
 TplFrameDepStats CreateTplFrameDepStats(int frame_height, int frame_width,
-                                        int min_block_size);
+                                        int min_block_size, bool has_alt_stats);
 
 TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
                                         int unit_count);
 
+Status FillTplUnitDepStats(TplFrameDepStats &frame_dep_stats,
+                           const TplFrameStats &frame_stats,
+                           const std::vector<TplBlockStats> &block_stats_list);
+
 StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
     const TplFrameStats &frame_stats);
 
@@ -80,7 +86,7 @@
     const TplFrameDepStats &frame_dep_stats);
 
 double TplFrameDepStatsAccumulateInterCost(
-    const TplFrameDepStats &frame_dep_stats);
+    const std::vector<std::vector<TplUnitDepStats>> &unit_stats);
 
 double TplFrameDepStatsAccumulate(const TplFrameDepStats &frame_dep_stats);
 
diff --git a/test/ratectrl_qmode_test.cc b/test/ratectrl_qmode_test.cc
index 3c947b0..dcf72d2 100644
--- a/test/ratectrl_qmode_test.cc
+++ b/test/ratectrl_qmode_test.cc
@@ -314,7 +314,7 @@
 
 static TplFrameStats CreateToyTplFrameStatsWithDiffSizes(int min_block_size,
                                                          int max_block_size) {
-  TplFrameStats frame_stats;
+  TplFrameStats frame_stats = {};
   const int max_h = max_block_size;
   const int max_w = max_h;
   const int count = max_block_size / min_block_size;
@@ -455,7 +455,7 @@
   // ref frame with coding_idx 0
   TplFrameDepStats frame_dep_stats0 =
       CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
-                             frame_stats.min_block_size);
+                             frame_stats.min_block_size, false);
   gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats0);
 
   // cur frame with coding_idx 1
@@ -492,13 +492,13 @@
   // ref frame with coding_idx 0
   const TplFrameDepStats frame_dep_stats0 =
       CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
-                             frame_stats.min_block_size);
+                             frame_stats.min_block_size, false);
   gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats0);
 
   // ref frame with coding_idx 1
   const TplFrameDepStats frame_dep_stats1 =
       CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
-                             frame_stats.min_block_size);
+                             frame_stats.min_block_size, false);
   gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats1);
 
   // cur frame with coding_idx 2
@@ -546,7 +546,7 @@
   // ref frame with coding_idx 0
   gop_dep_stats.frame_dep_stats_list.push_back(
       CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
-                             frame_stats.min_block_size));
+                             frame_stats.min_block_size, false));
 
   // cur frame with coding_idx 1
   const StatusOr<TplFrameDepStats> frame_dep_stats =