Extra boost when the arf is in a static region.

This CL also factors out the part to calculate arf boost and average
correlation coefficient, so it can be used when determining
frame-level qp based on tpl stats.

BUG: b/260883367
BUG: b/260859962

Change-Id: I9baf11993d43ca682bce80f1b79ff4a06be23914
diff --git a/av1/qmode_rc/ratectrl_qmode.cc b/av1/qmode_rc/ratectrl_qmode.cc
index af01a18..28cd7a6 100644
--- a/av1/qmode_rc/ratectrl_qmode.cc
+++ b/av1/qmode_rc/ratectrl_qmode.cc
@@ -1571,6 +1571,133 @@
   return stats_list[index].is_flash || stats_list[index - 1].is_flash;
 }
 
+void inline SetUpFrameIndices(GopFrameType update_type, int stats_size,
+                              int this_gop_len, int next_gop_len,
+                              int &this_index, int &first_index,
+                              int &last_index, int &ref_before_index,
+                              int &ref_after_index) {
+  if (update_type == GopFrameType::kRegularKey) {
+    this_index = 0;
+    first_index = 1;
+    last_index = stats_size - 1;
+    ref_before_index = -1;
+    ref_after_index = -1;
+  } else if (update_type == GopFrameType::kRegularGolden) {
+    // TODO(b/260859962): Need to consider the situation when arf is not
+    // used
+    this_index = 0;
+    first_index = 1;
+    last_index = this_gop_len - 2;
+    ref_before_index = -1;
+    ref_after_index = this_gop_len - 1;
+  } else {
+    // arf type
+    // TODO(b/260859962): It looks like in this case the last arf should
+    // actually be at index -1. This for now should be accurate enough, but
+    // in the future it is better to have the exact index of last arf.
+    this_index = this_gop_len - 1;
+    first_index = 1;
+    last_index = next_gop_len >= 4 ? this_gop_len + next_gop_len - 2
+                                   : this_gop_len + next_gop_len - 1;
+    ref_before_index = 0;
+    ref_after_index = next_gop_len >= 4 ? this_gop_len + next_gop_len - 1 : -1;
+  }
+}
+
+// Return the accumulated score of a frame, considering its influence on the
+// frames from first_index to last_index (both inclusive). When ref_before_index
+// >= 0, only consider the frames where the current frame has a larger
+// correlation than the frame at ref_before_index. Same for ref_after_index.
+// This function also calculates and returns the average correlation coefficient
+// of this frame to the affected frames through the parameter avg_correlation.
+double GetAccumulatedScore(const FirstpassInfo &firstpass_info, int this_index,
+                           int first_index, int last_index,
+                           int ref_before_index, int ref_after_index,
+                           double &avg_correlation) {
+  assert(ref_before_index < 0 || ref_before_index < first_index);
+  assert(ref_after_index < 0 || ref_after_index > last_index);
+  double score = 0.0;
+  int count = 0;
+  avg_correlation = 0.0;
+  // Check the influence of this frame to the frames before it
+  for (int f = this_index - 1; f >= first_index; --f) {
+    // The contribution of this frame to frame f
+    double coeff_this = 1.0;
+    for (int k = this_index; k > f; --k) {
+      if (CheckFlash(firstpass_info.stats_list, k)) continue;
+      coeff_this *= firstpass_info.stats_list[k].cor_coeff;
+    }
+    // The contribution of frame at ref_before_index to frame f
+    if (ref_before_index >= 0) {
+      double coeff_last = 1.0;
+      for (int k = ref_before_index + 1; k <= f; ++k) {
+        if (CheckFlash(firstpass_info.stats_list, k)) continue;
+        coeff_last *= firstpass_info.stats_list[k].cor_coeff;
+      }
+      if (coeff_last > coeff_this) break;
+    }
+    ++count;
+    avg_correlation += firstpass_info.stats_list[f + 1].cor_coeff;
+
+    // If this is a flash, although we ignore it in the accumulation, we
+    // still count it for this frame so it will probably have a low
+    // correlation
+    if (firstpass_info.stats_list[f].is_flash)
+      coeff_this *= firstpass_info.stats_list[f].cor_coeff;
+
+    const double this_cor =
+        coeff_this * sqrt(std::max((firstpass_info.stats_list[f].intra_error -
+                                    firstpass_info.stats_list[f].noise_var) /
+                                       firstpass_info.stats_list[f].intra_error,
+                                   0.5));
+    score += this_cor;
+  }
+
+  // Check the influence of this frame to the frames after it
+  for (int f = this_index + 1; f <= last_index; ++f) {
+    // The contribution of this frame to frame f
+    double coeff_this = 1.0;
+    for (int k = this_index + 1; k <= f; ++k) {
+      if (CheckFlash(firstpass_info.stats_list, k)) continue;
+      coeff_this *= firstpass_info.stats_list[k].cor_coeff;
+    }
+
+    // The contribution of frame at ref_after_index to frame f
+    if (ref_after_index >= 0) {
+      double coeff_next = 1.0;
+      for (int k = ref_after_index; k > f; --k) {
+        if (CheckFlash(firstpass_info.stats_list, k)) continue;
+        coeff_next *= firstpass_info.stats_list[k].cor_coeff;
+      }
+      if (coeff_next > coeff_this) break;
+    }
+    ++count;
+    avg_correlation += firstpass_info.stats_list[f].cor_coeff;
+
+    // If this is a flash, although we ignore it in the accumulation, we
+    // still count it for this frame so it will probably have a low
+    // correlation
+    if (firstpass_info.stats_list[f].is_flash)
+      coeff_this *= firstpass_info.stats_list[f].cor_coeff;
+
+    const double this_cor =
+        coeff_this * sqrt(std::max((firstpass_info.stats_list[f].intra_error -
+                                    firstpass_info.stats_list[f].noise_var) /
+                                       firstpass_info.stats_list[f].intra_error,
+                                   0.5));
+    score += this_cor;
+  }
+  if (count > 0) avg_correlation /= static_cast<double>(count);
+  return score;
+}
+
+int AdjustStaticQp(double avg_correlation, double score, int q_index) {
+  if (avg_correlation < 0.99) return q_index;
+  const double factor = q_index * score / 400 + 1.0;
+
+  return static_cast<int>(q_index / factor);
+}
+
 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithFp(
     const GopStruct &gop_struct, const FirstpassInfo &firstpass_info,
     const std::vector<LookaheadStats> &lookahead_stats,
@@ -1607,102 +1734,27 @@
         gop_frame.update_type == GopFrameType::kIntermediateOverlay ||
         gop_frame.update_type == GopFrameType::kRegularLeaf) {
       param.q_index = rc_param_.base_q_index;
-    } else if (gop_frame.update_type == GopFrameType::kRegularKey) {
-      // Accumulate correlation coefficients to determine KF boost
-      double boost = 0.0;
-      double coeff_kf = 1.0;
-      for (int i = 1; i < stats_size; ++i) {
-        if (CheckFlash(analyzed_fp_info.stats_list, i)) continue;
-        coeff_kf *= analyzed_fp_info.stats_list[i].cor_coeff;
-        const double this_cor =
-            coeff_kf *
-            sqrt(std::max((analyzed_fp_info.stats_list[i].intra_error -
-                           analyzed_fp_info.stats_list[i].noise_var) /
-                              analyzed_fp_info.stats_list[i].intra_error,
-                          0.5));
-        boost += this_cor;
-      }
-      boost = std::min(std::max(sqrt(boost), 1.0), 6.0);
-      const double qstep_ratio = 1.0 / boost;
-      param.q_index = av1_get_q_index_from_qstep_ratio(rc_param_.base_q_index,
-                                                       qstep_ratio, AOM_BITS_8);
-      if (rc_param_.base_q_index) param.q_index = std::max(param.q_index, 1);
-      active_best_quality = param.q_index;
-    } else if (gop_frame.update_type == GopFrameType::kRegularGolden ||
+    } else if (gop_frame.update_type == GopFrameType::kRegularKey ||
+               gop_frame.update_type == GopFrameType::kRegularGolden ||
                gop_frame.update_type == GopFrameType::kRegularArf) {
-      double boost = 0.0;
+      int this_index, first_index, last_index, ref_before_index,
+          ref_after_index;
+      SetUpFrameIndices(gop_frame.update_type, stats_size, this_gop_len,
+                        next_gop_len, this_index, first_index, last_index,
+                        ref_before_index, ref_after_index);
 
-      // Check the influence of this arf frame to the frames before it
-      for (int f = this_gop_len - 2; f > 0; --f) {
-        // The contribution of this arf to frame f
-        double coeff_this = 1.0;
-        for (int k = this_gop_len - 1; k > f; --k) {
-          if (CheckFlash(analyzed_fp_info.stats_list, k)) continue;
-          coeff_this *= analyzed_fp_info.stats_list[k].cor_coeff;
-        }
-
-        // The contribution of last arf to frame f
-        double coeff_last = 1.0;
-        for (int k = 1; k <= f; ++k) {
-          if (CheckFlash(analyzed_fp_info.stats_list, k)) continue;
-          coeff_last *= analyzed_fp_info.stats_list[k].cor_coeff;
-        }
-
-        if (coeff_last > coeff_this) break;
-
-        // If this is a flash, although we ignore it in the accumulation, we
-        // still count it for this frame so it will probably have a low
-        // correlation
-        if (analyzed_fp_info.stats_list[f].is_flash)
-          coeff_this *= analyzed_fp_info.stats_list[f].cor_coeff;
-
-        const double this_cor =
-            coeff_this *
-            sqrt(std::max((analyzed_fp_info.stats_list[f].intra_error -
-                           analyzed_fp_info.stats_list[f].noise_var) /
-                              analyzed_fp_info.stats_list[f].intra_error,
-                          0.5));
-        boost += this_cor;
-      }
-
-      // Check the influence of this arf frame to the frames after it
-      for (int f = this_gop_len; f < this_gop_len + next_gop_len; ++f) {
-        // The contribution of this arf to frame f
-        double coeff_this = 1.0;
-        for (int k = this_gop_len; k <= f; ++k) {
-          if (CheckFlash(analyzed_fp_info.stats_list, k)) continue;
-          coeff_this *= analyzed_fp_info.stats_list[k].cor_coeff;
-        }
-
-        if (next_gop_len >= 4) {
-          // The contribution of next arf to frame f
-          double coeff_next = 1.0;
-          for (int k = this_gop_len + next_gop_len - 1; k > f; --k) {
-            if (CheckFlash(analyzed_fp_info.stats_list, k)) continue;
-            coeff_next *= analyzed_fp_info.stats_list[k].cor_coeff;
-          }
-          if (coeff_next > coeff_this) break;
-        }
-
-        // If this is a flash, although we ignore it in the accumulation, we
-        // still count it for this frame so it will probably have a low
-        // correlation
-        if (analyzed_fp_info.stats_list[f].is_flash)
-          coeff_this *= analyzed_fp_info.stats_list[f].cor_coeff;
-
-        const double this_cor =
-            coeff_this *
-            sqrt(std::max((analyzed_fp_info.stats_list[f].intra_error -
-                           analyzed_fp_info.stats_list[f].noise_var) /
-                              analyzed_fp_info.stats_list[f].intra_error,
-                          0.5));
-
-        boost += this_cor;
-      }
-      boost = std::min(std::max(sqrt(boost), 1.0), 4.0);
+      double avg_correlation = 0;
+      const double score = GetAccumulatedScore(
+          analyzed_fp_info, this_index, first_index, last_index,
+          ref_before_index, ref_after_index, avg_correlation);
+      const double boost = std::min(
+          std::max(sqrt(score), 1.0),
+          gop_frame.update_type == GopFrameType::kRegularKey ? 6.0 : 4.0);
       const double qstep_ratio = 1.0 / boost;
       param.q_index = av1_get_q_index_from_qstep_ratio(rc_param_.base_q_index,
                                                        qstep_ratio, AOM_BITS_8);
+      param.q_index = AdjustStaticQp(avg_correlation, score, param.q_index);
+
       if (rc_param_.base_q_index) param.q_index = std::max(param.q_index, 1);
       active_best_quality = param.q_index;
 
@@ -1721,7 +1773,8 @@
 }
 
 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithTpl(
-    const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
+    const GopStruct &gop_struct, const FirstpassInfo &firstpass_info,
+    const TplGopStats &tpl_gop_stats,
     const std::vector<LookaheadStats> &lookahead_stats,
     const RefFrameTable &ref_frame_table_snapshot_init) {
   const std::vector<RefFrameTable> ref_frame_table_list = GetRefFrameTableList(
@@ -1736,6 +1789,23 @@
   }
   const int frame_count =
       static_cast<int>(tpl_gop_stats.frame_stats_list.size());
+
+  const int stats_size = static_cast<int>(firstpass_info.stats_list.size());
+  const FirstpassInfo analyzed_fp_info =
+      AnalyzeFpStats(std::move(firstpass_info));
+
+  const int this_gop_len = gop_struct.show_frame_count;
+  const int next_gop_len =
+      lookahead_stats.empty()
+          ? 0
+          : lookahead_stats[0].gop_struct[0].show_frame_count;
+  if (stats_size < this_gop_len + next_gop_len) {
+    Status status;
+    status.code = AOM_CODEC_INVALID_PARAM;
+    status.message = "The firstpass info length is insufficient.";
+    return status;
+  }
+
   const int active_worst_quality = rc_param_.base_q_index;
   int active_best_quality = rc_param_.base_q_index;
 
@@ -1796,6 +1866,18 @@
       const double qstep_ratio = sqrt(1 / frame_importance);
       param.q_index = av1_get_q_index_from_qstep_ratio(rc_param_.base_q_index,
                                                        qstep_ratio, AOM_BITS_8);
+      int this_index, first_index, last_index, ref_before_index,
+          ref_after_index;
+      SetUpFrameIndices(gop_frame.update_type, stats_size, this_gop_len,
+                        next_gop_len, this_index, first_index, last_index,
+                        ref_before_index, ref_after_index);
+
+      double avg_correlation = 0;
+      const double score = GetAccumulatedScore(
+          analyzed_fp_info, this_index, first_index, last_index,
+          ref_before_index, ref_after_index, avg_correlation);
+      param.q_index = AdjustStaticQp(avg_correlation, score, param.q_index);
+
       if (rc_param_.base_q_index) param.q_index = AOMMAX(param.q_index, 1);
       active_best_quality = param.q_index;
 
@@ -1857,10 +1939,8 @@
     }
   }
 
-  // TODO(b/260859962): Currently firstpass stats are used as an alternative,
-  // but we could also combine it with tpl results in the future for more
-  // stable qp determination.
-  return GetGopEncodeInfoWithTpl(gop_struct, tpl_gop_stats, lookahead_stats,
+  return GetGopEncodeInfoWithTpl(gop_struct, firstpass_info, tpl_gop_stats,
+                                 lookahead_stats,
                                  ref_frame_table_snapshot_init);
 }
 
diff --git a/av1/qmode_rc/ratectrl_qmode.h b/av1/qmode_rc/ratectrl_qmode.h
index 86e15b0..4f2180c 100644
--- a/av1/qmode_rc/ratectrl_qmode.h
+++ b/av1/qmode_rc/ratectrl_qmode.h
@@ -142,7 +142,8 @@
       const std::vector<LookaheadStats> &lookahead_stats,
       const RefFrameTable &ref_frame_table_snapshot_init);
   StatusOr<GopEncodeInfo> GetGopEncodeInfoWithTpl(
-      const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
+      const GopStruct &gop_struct, const FirstpassInfo &firstpass_info,
+      const TplGopStats &tpl_gop_stats,
       const std::vector<LookaheadStats> &lookahead_stats,
       const RefFrameTable &ref_frame_table_snapshot_init);
 };