Add DetectGopCut() to PartitionGopIntervals()

Bug: b/221916304
Change-Id: Ic7c1926f81289b83946725650a3fff54968c02a1
diff --git a/av1/encoder/pass2_strategy.c b/av1/encoder/pass2_strategy.c
index 9e086ee..b495091 100644
--- a/av1/encoder/pass2_strategy.c
+++ b/av1/encoder/pass2_strategy.c
@@ -512,12 +512,12 @@
   gf_stats->gf_group_inactive_zone_rows += stats->inactive_zone_rows;
 }
 
-static void accumulate_next_frame_stats(const FIRSTPASS_STATS *stats,
-                                        const int flash_detected,
-                                        const int frames_since_key,
-                                        const int cur_idx,
-                                        GF_GROUP_STATS *gf_stats, int f_w,
-                                        int f_h) {
+void av1_accumulate_next_frame_stats(const FIRSTPASS_STATS *stats,
+                                     const int flash_detected,
+                                     const int frames_since_key,
+                                     const int cur_idx,
+                                     GF_GROUP_STATS *gf_stats, int f_w,
+                                     int f_h) {
   accumulate_frame_motion_stats(stats, gf_stats, f_w, f_h);
   // sum up the metric values of current gf group
   gf_stats->avg_sr_coded_error += stats->sr_coded_error;
@@ -1929,8 +1929,9 @@
       flash_detected = detect_flash(twopass, &cpi->twopass_frame, 0);
       // TODO(bohanli): remove redundant accumulations here, or unify
       // this and the ones in define_gf_group
-      accumulate_next_frame_stats(&next_frame, flash_detected,
-                                  rc->frames_since_key, i, &gf_stats, f_w, f_h);
+      av1_accumulate_next_frame_stats(&next_frame, flash_detected,
+                                      rc->frames_since_key, i, &gf_stats, f_w,
+                                      f_h);
 
       cut_here = detect_gf_cut(cpi, i, cur_start, flash_detected,
                                active_max_gf_interval, active_min_gf_interval,
@@ -2266,8 +2267,9 @@
     flash_detected = detect_flash(twopass, &cpi->twopass_frame, 0);
 
     // accumulate stats for next frame
-    accumulate_next_frame_stats(next_frame, flash_detected,
-                                rc->frames_since_key, i, gf_stats, f_w, f_h);
+    av1_accumulate_next_frame_stats(next_frame, flash_detected,
+                                    rc->frames_since_key, i, gf_stats, f_w,
+                                    f_h);
 
     ++i;
   }
diff --git a/av1/encoder/pass2_strategy.h b/av1/encoder/pass2_strategy.h
index c54b8c4..6234623 100644
--- a/av1/encoder/pass2_strategy.h
+++ b/av1/encoder/pass2_strategy.h
@@ -135,6 +135,12 @@
                        int *num_fpstats_used, int *num_fpstats_required,
                        int project_gfu_boost);
 
+void av1_accumulate_next_frame_stats(const FIRSTPASS_STATS *stats,
+                                     const int flash_detected,
+                                     const int frames_since_key,
+                                     const int cur_idx,
+                                     GF_GROUP_STATS *gf_stats, int f_w,
+                                     int f_h);
 // Identify stable and unstable regions from first pass stats.
 // stats_start points to the first frame to analyze.
 // |offset| is the offset from the current frame to the frame stats_start is
diff --git a/av1/ratectrl_qmode.cc b/av1/ratectrl_qmode.cc
index 666a5e3..6ae3c3e 100644
--- a/av1/ratectrl_qmode.cc
+++ b/av1/ratectrl_qmode.cc
@@ -422,6 +422,24 @@
   return -1;
 }
 
+// Detects a flash at stats_list[index] through a high relative
+// pcnt_second_ref score in the following frame, stats_list[index + 1].
+// Returns false when stats_list has no frame after |index|.
+static bool DetectFlash(const std::vector<FIRSTPASS_STATS> &stats_list,
+                        int index) {
+  int next_index = index + 1;
+  if (next_index >= static_cast<int>(stats_list.size())) return false;
+  const FIRSTPASS_STATS &next_frame = stats_list[next_index];
+
+  // What we are looking for here is a situation where there is a
+  // brief break in prediction (such as a flash) but subsequent frames
+  // are reasonably well predicted by an earlier (pre-flash) frame.
+  // The recovery after a flash is indicated by a pcnt_second_ref that
+  // exceeds pcnt_inter and is at least 0.5.
+  return next_frame.pcnt_second_ref > next_frame.pcnt_inter &&
+         next_frame.pcnt_second_ref >= 0.5;
+}
+
 #define MIN_SHRINK_LEN 6
 
 // This function takes in a suggesting gop interval from cur_start to cur_last,
@@ -582,6 +600,76 @@
   return cur_last;
 }
 
+// Tests for a condition where a complex transition is followed by a static
+// section, for example a fade between slides in a slide show. Returns true
+// on a transition to still; this helps with kf and gf positioning.
+static bool DetectTransitionToStill(
+    const std::vector<FIRSTPASS_STATS> &stats_list, int next_stats_index,
+    int min_gop_show_frame_count, int frame_interval, int still_interval,
+    double loop_decay_rate, double last_decay_rate) {
+  // Break clause to detect very still sections after motion, for example
+  // a static image after a fade or other transition instead of a clean
+  // scene cut. Only tested once frame_interval > min_gop_show_frame_count.
+  if (frame_interval > min_gop_show_frame_count && loop_decay_rate >= 0.999 &&
+      last_decay_rate < 0.9) {
+    int stats_count = static_cast<int>(stats_list.size());
+    int stats_left = stats_count - next_stats_index;
+    if (stats_left >= still_interval) {
+      // Look ahead still_interval frames; stop at the first non-static one.
+      int j;
+      for (j = 0; j < still_interval; ++j) {
+        const FIRSTPASS_STATS &stats = stats_list[next_stats_index + j];
+        if (stats.pcnt_inter - stats.pcnt_motion < 0.999) break;
+      }
+      // Signal a transition to still only if every frame passed the check.
+      return j == still_interval;
+    }
+  }
+  return false;
+}
+
+static int DetectGopCut(const std::vector<FIRSTPASS_STATS> &stats_list,
+                        int start_idx, int candidate_cut_idx, int next_key_idx,
+                        int flash_detected, int min_gop_show_frame_count,
+                        int max_gop_show_frame_count, int frame_width,
+                        int frame_height, const GF_GROUP_STATS &gf_stats) {
+  (void)max_gop_show_frame_count;  // Unused; cast silences the warning.
+  const int candidate_gop_size = candidate_cut_idx - start_idx;
+
+  if (!flash_detected) {
+    // Cut on very still sections after motion (e.g. a static image after a
+    // fade). NOTE(review): look-ahead starts at start_idx - confirm intent.
+    if (DetectTransitionToStill(stats_list, start_idx, min_gop_show_frame_count,
+                                candidate_gop_size, 5, gf_stats.loop_decay_rate,
+                                gf_stats.last_loop_decay_rate)) {
+      return 1;
+    }
+    const double arf_abs_zoom_thresh = 4.4;
+    // Motion breakout threshold for the check below depends on image size.
+    const double mv_ratio_accumulator_thresh =
+        (frame_height + frame_width) / 4.0;
+    // Conditions to break out once the min interval has been reached.
+    if (candidate_gop_size >= min_gop_show_frame_count &&
+        // If possible don't break very close to a kf
+        (next_key_idx - candidate_cut_idx >= min_gop_show_frame_count) &&
+        (candidate_gop_size & 0x01) &&
+        (gf_stats.mv_ratio_accumulator > mv_ratio_accumulator_thresh ||
+         gf_stats.abs_mv_in_out_accumulator > arf_abs_zoom_thresh)) {
+      return 1;
+    }
+  }
+
+  // TODO(b/231489624): Check if we need this part.
+  // If almost totally static, we will not use the max GF length later,
+  // so we can continue for more frames.
+  // if ((candidate_gop_size >= active_max_gf_interval + 1) &&
+  //     !is_almost_static(gf_stats->zero_motion_accumulator,
+  //                       twopass->kf_zeromotion_pct, cpi->ppi->lap_enabled)) {
+  //   return 0;
+  // }
+  return 0;
+}
+
 /*!\brief Determine the length of future GF groups.
  *
  * \ingroup gf_group_algo
@@ -612,7 +700,6 @@
   GF_GROUP_STATS gf_stats;
   InitGFStats(&gf_stats);
   int num_stats = static_cast<int>(stats_list.size());
-  int stats_in_loop_index = order_index;
   while (i + order_index < num_stats) {
     // reaches next key frame, break here
     if (i >= frames_to_key) {
@@ -621,9 +708,29 @@
       // reached maximum len, but nothing special yet (almost static)
       // let's look at the next interval
       cut_here = 1;
-    } else if (stats_in_loop_index >= num_stats) {
-      // reaches last frame, break
-      cut_here = 2;
+    } else {
+      // Test for the case where there is a brief flash but the prediction
+      // quality back to an earlier frame is then restored.
+      const int gop_start_idx = cur_start + order_index;
+      const int candidate_gop_cut_idx = i + order_index;
+      const int next_key_idx = frames_to_key + order_index;
+      const bool flash_detected =
+          DetectFlash(stats_list, candidate_gop_cut_idx);
+
+      // TODO(bohanli): remove redundant accumulations here, or unify
+      // this and the ones in define_gf_group
+      const FIRSTPASS_STATS *stats = &stats_list[candidate_gop_cut_idx];
+      av1_accumulate_next_frame_stats(stats, flash_detected, frames_since_key,
+                                      i, &gf_stats, rc_param.frame_width,
+                                      rc_param.frame_height);
+
+      // TODO(angiebird): Can we simplify this part? Looks like we are going to
+      // change the gop cut index with FindBetterGopCut() anyway.
+      cut_here = DetectGopCut(
+          stats_list, gop_start_idx, candidate_gop_cut_idx, next_key_idx,
+          flash_detected, rc_param.min_gop_show_frame_count,
+          rc_param.max_gop_show_frame_count, rc_param.frame_width,
+          rc_param.frame_height, gf_stats);
     }
 
     if (!cut_here) {
@@ -639,13 +746,13 @@
     cut_pos.push_back(cur_last);
 
     // reset pointers to the shrunken location
-    stats_in_loop_index = order_index + cur_last;
     cur_start = cur_last;
     int cur_region_idx =
         FindRegionIndex(regions_list, cur_start + 1 + frames_since_key);
     if (cur_region_idx >= 0)
       if (regions_list[cur_region_idx].type == SCENECUT_REGION) cur_start++;
 
+    // TODO(angiebird): Why do we need to break here?
     if (cut_here > 1 && cur_last == original_last) break;
     // reset accumulators
     InitGFStats(&gf_stats);
diff --git a/av1/ratectrl_qmode_interface.h b/av1/ratectrl_qmode_interface.h
index b06e182..5fa8492 100644
--- a/av1/ratectrl_qmode_interface.h
+++ b/av1/ratectrl_qmode_interface.h
@@ -33,6 +33,8 @@
   int min_gop_show_frame_count;
   int max_ref_frames;
   int base_q_index;
+  int frame_width;
+  int frame_height;
 };
 
 struct TplBlockStats {