Implement the function that uses fp stats to determine arf q.

BUG=b/260859962

Change-Id: Ifab636761c1d5c2078406e32121d3cc4282bb6f9
diff --git a/av1/qmode_rc/ratectrl_qmode.cc b/av1/qmode_rc/ratectrl_qmode.cc
index 7e10be4..6ca6d81 100644
--- a/av1/qmode_rc/ratectrl_qmode.cc
+++ b/av1/qmode_rc/ratectrl_qmode.cc
@@ -851,6 +851,18 @@
   return gf_intervals;
 }
 
+// Make a copy of the first pass stats, and analyze them
+FirstpassInfo AnalyzeFpStats(FirstpassInfo firstpass_info) {
+  const int stats_size = static_cast<int>(firstpass_info.stats_list.size());
+  av1_mark_flashes(firstpass_info.stats_list.data(),
+                   firstpass_info.stats_list.data() + stats_size);
+  av1_estimate_noise(firstpass_info.stats_list.data(),
+                     firstpass_info.stats_list.data() + stats_size);
+  av1_estimate_coeff(firstpass_info.stats_list.data(),
+                     firstpass_info.stats_list.data() + stats_size);
+  return firstpass_info;
+}
+
 StatusOr<GopStructList> AV1RateControlQMode::DetermineGopInfo(
     const FirstpassInfo &firstpass_info) {
   const int stats_size = static_cast<int>(firstpass_info.stats_list.size());
@@ -858,18 +870,12 @@
   RefFrameManager ref_frame_manager(rc_param_.ref_frame_table_size,
                                     rc_param_.max_ref_frames);
 
-  // Make a copy of the first pass stats, and analyze them
-  FirstpassInfo fp_info_copy = firstpass_info;
-  av1_mark_flashes(fp_info_copy.stats_list.data(),
-                   fp_info_copy.stats_list.data() + stats_size);
-  av1_estimate_noise(fp_info_copy.stats_list.data(),
-                     fp_info_copy.stats_list.data() + stats_size);
-  av1_estimate_coeff(fp_info_copy.stats_list.data(),
-                     fp_info_copy.stats_list.data() + stats_size);
+  const FirstpassInfo analyzed_fp_info =
+      AnalyzeFpStats(std::move(firstpass_info));
 
   int global_coding_idx_offset = 0;
   int global_order_idx_offset = 0;
-  std::vector<int> key_frame_list = GetKeyFrameList(fp_info_copy);
+  std::vector<int> key_frame_list = GetKeyFrameList(analyzed_fp_info);
   key_frame_list.push_back(stats_size);  // a sentinel value
   for (size_t ki = 0; ki + 1 < key_frame_list.size(); ++ki) {
     int frames_to_key = key_frame_list[ki + 1] - key_frame_list[ki];
@@ -877,11 +883,11 @@
 
     std::vector<REGIONS> regions_list(MAX_FIRSTPASS_ANALYSIS_FRAMES);
     int total_regions = 0;
-    av1_identify_regions(fp_info_copy.stats_list.data() + key_order_index,
+    av1_identify_regions(analyzed_fp_info.stats_list.data() + key_order_index,
                          frames_to_key, 0, regions_list.data(), &total_regions);
     regions_list.resize(total_regions);
     std::vector<int> gf_intervals = PartitionGopIntervals(
-        rc_param_, fp_info_copy.stats_list, regions_list, key_order_index,
+        rc_param_, analyzed_fp_info.stats_list, regions_list, key_order_index,
         /*frames_since_key=*/0, frames_to_key);
     for (size_t gi = 0; gi < gf_intervals.size(); ++gi) {
       const bool has_key_frame = gi == 0;
@@ -1466,12 +1472,151 @@
   return gop_encode_info;
 }
 
+bool CheckFlash(const std::vector<FIRSTPASS_STATS> &stats_list, int index) {
+  assert(index >= 1);
+  return stats_list[index].is_flash || stats_list[index - 1].is_flash;
+}
+
 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithFp(
-    const GopStruct &gop_struct,
-    const FirstpassInfo &firstpass_info AOM_UNUSED) {
-  // TODO(b/260859962): This is currently a placeholder. Should use the fp
-  // stats to calculate frame-level qp.
-  return GetGopEncodeInfoWithNoStats(gop_struct);
+    const GopStruct &gop_struct, const FirstpassInfo &firstpass_info,
+    const std::vector<LookaheadStats> &lookahead_stats) {
+  const int stats_size = static_cast<int>(firstpass_info.stats_list.size());
+  const FirstpassInfo analyzed_fp_info =
+      AnalyzeFpStats(std::move(firstpass_info));
+
+  const int this_gop_len = gop_struct.show_frame_count;
+  const int next_gop_len =
+      lookahead_stats.empty()
+          ? 0
+          : lookahead_stats[0].gop_struct[0].show_frame_count;
+  if (stats_size < this_gop_len + next_gop_len) {
+    Status status;
+    status.code = AOM_CODEC_INVALID_PARAM;
+    status.message = "The firstpass info length is insufficient.";
+    return status;
+  }
+
+  GopEncodeInfo gop_encode_info;
+  const int frame_count = static_cast<int>(gop_struct.gop_frame_list.size());
+  const int active_worst_quality = rc_param_.base_q_index;
+  int active_best_quality = rc_param_.base_q_index;
+  for (int i = 0; i < frame_count; ++i) {
+    FrameEncodeParameters param;
+    const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
+    if (gop_frame.update_type == GopFrameType::kOverlay ||
+        gop_frame.update_type == GopFrameType::kIntermediateOverlay ||
+        gop_frame.update_type == GopFrameType::kRegularLeaf) {
+      param.q_index = rc_param_.base_q_index;
+    } else if (gop_frame.update_type == GopFrameType::kRegularKey) {
+      // Accumulate correlation coefficients to determine KF boost
+      double boost = 0.0;
+      double coeff_kf = 1.0;
+      for (int i = 1; i < stats_size; ++i) {
+        if (CheckFlash(analyzed_fp_info.stats_list, i)) continue;
+        coeff_kf *= analyzed_fp_info.stats_list[i].cor_coeff;
+        const double this_cor =
+            coeff_kf *
+            sqrt(std::max((analyzed_fp_info.stats_list[i].intra_error -
+                           analyzed_fp_info.stats_list[i].noise_var) /
+                              analyzed_fp_info.stats_list[i].intra_error,
+                          0.5));
+        boost += this_cor;
+      }
+      boost = std::min(std::max(sqrt(boost), 1.0), 6.0);
+      const double qstep_ratio = 1.0 / boost;
+      param.q_index = av1_get_q_index_from_qstep_ratio(rc_param_.base_q_index,
+                                                       qstep_ratio, AOM_BITS_8);
+      if (rc_param_.base_q_index) param.q_index = std::max(param.q_index, 1);
+      active_best_quality = param.q_index;
+    } else if (gop_frame.update_type == GopFrameType::kRegularGolden ||
+               gop_frame.update_type == GopFrameType::kRegularArf) {
+      double boost = 0.0;
+
+      // Check the influence of this arf frame to the frames before it
+      for (int f = this_gop_len - 1; f > 0; --f) {
+        // The contribution of this arf to frame f
+        double coeff_this = 1.0;
+        for (int k = this_gop_len; k > f; --k) {
+          if (CheckFlash(analyzed_fp_info.stats_list, k)) continue;
+          coeff_this *= analyzed_fp_info.stats_list[k].cor_coeff;
+        }
+
+        // The contribution of last arf to frame f
+        double coeff_last = 1.0;
+        for (int k = 1; k <= f; ++k) {
+          if (CheckFlash(analyzed_fp_info.stats_list, k)) continue;
+          coeff_last *= analyzed_fp_info.stats_list[k].cor_coeff;
+        }
+
+        if (coeff_last > coeff_this) break;
+
+        // If this is a flash, although we ignore it in the accumulation, we
+        // still count it for this frame so it will probably have a low
+        // correlation
+        if (analyzed_fp_info.stats_list[f].is_flash)
+          coeff_this *= analyzed_fp_info.stats_list[f].cor_coeff;
+
+        const double this_cor =
+            coeff_this *
+            sqrt(std::max((analyzed_fp_info.stats_list[f].intra_error -
+                           analyzed_fp_info.stats_list[f].noise_var) /
+                              analyzed_fp_info.stats_list[f].intra_error,
+                          0.5));
+        boost += this_cor;
+      }
+
+      // Check the influence of this arf frame to the frames after it
+      for (int f = this_gop_len + 1; f < this_gop_len + next_gop_len; ++f) {
+        // The contribution of this arf to frame f
+        double coeff_this = 1.0;
+        for (int k = this_gop_len + 1; k <= f; ++k) {
+          if (CheckFlash(analyzed_fp_info.stats_list, k)) continue;
+          coeff_this *= analyzed_fp_info.stats_list[k].cor_coeff;
+        }
+
+        if (next_gop_len >= 4) {
+          // The contribution of next arf to frame f
+          double coeff_next = 1.0;
+          for (int k = this_gop_len + next_gop_len; k > f; --k) {
+            if (CheckFlash(analyzed_fp_info.stats_list, k)) continue;
+            coeff_next *= analyzed_fp_info.stats_list[k].cor_coeff;
+          }
+          if (coeff_next > coeff_this) break;
+        }
+
+        // If this is a flash, although we ignore it in the accumulation, we
+        // still count it for this frame so it will probably have a low
+        // correlation
+        if (analyzed_fp_info.stats_list[f].is_flash)
+          coeff_this *= analyzed_fp_info.stats_list[f].cor_coeff;
+
+        const double this_cor =
+            coeff_this *
+            sqrt(std::max((analyzed_fp_info.stats_list[f].intra_error -
+                           analyzed_fp_info.stats_list[f].noise_var) /
+                              analyzed_fp_info.stats_list[f].intra_error,
+                          0.5));
+
+        boost += this_cor;
+      }
+      boost = std::min(std::max(sqrt(boost), 1.0), 4.0);
+      const double qstep_ratio = 1.0 / boost;
+      param.q_index = av1_get_q_index_from_qstep_ratio(rc_param_.base_q_index,
+                                                       qstep_ratio, AOM_BITS_8);
+      if (rc_param_.base_q_index) param.q_index = std::max(param.q_index, 1);
+      active_best_quality = param.q_index;
+    } else {
+      // Intermediate ARFs
+      assert(gop_frame.layer_depth >= 1);
+      const int depth_factor = 1 << (gop_frame.layer_depth - 1);
+      param.q_index =
+          (active_worst_quality * (depth_factor - 1) + active_best_quality) /
+          depth_factor;
+    }
+    param.rdmult = GetRDMult(gop_frame, param.q_index);
+    gop_encode_info.param_list.push_back(param);
+  }
+  return gop_encode_info;
 }
 
 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithTpl(
@@ -1586,8 +1731,9 @@
 }
 
 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetTplPassGopEncodeInfo(
-    const GopStruct &gop_struct, const FirstpassInfo &firstpass_info) {
-  return GetGopEncodeInfoWithFp(gop_struct, firstpass_info);
+    const GopStruct &gop_struct,
+    const FirstpassInfo &firstpass_info AOM_UNUSED) {
+  return GetGopEncodeInfoWithNoStats(gop_struct);
 }
 
 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfo(
diff --git a/av1/qmode_rc/ratectrl_qmode.h b/av1/qmode_rc/ratectrl_qmode.h
index 70e112a..be99db2 100644
--- a/av1/qmode_rc/ratectrl_qmode.h
+++ b/av1/qmode_rc/ratectrl_qmode.h
@@ -136,7 +136,8 @@
   StatusOr<GopEncodeInfo> GetGopEncodeInfoWithNoStats(
       const GopStruct &gop_struct);
   StatusOr<GopEncodeInfo> GetGopEncodeInfoWithFp(
-      const GopStruct &gop_struct, const FirstpassInfo &firstpass_info);
+      const GopStruct &gop_struct, const FirstpassInfo &firstpass_info,
+      const std::vector<LookaheadStats> &lookahead_stats);
   StatusOr<GopEncodeInfo> GetGopEncodeInfoWithTpl(
       const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
       const std::vector<LookaheadStats> &lookahead_stats,