/*
 * Copyright (c) 2022, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
#include "av1/qmode_rc/ratectrl_qmode.h"

#include <algorithm>
#include <cassert>
#include <climits>
#include <functional>
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "aom/aom_codec.h"
#include "av1/encoder/pass2_strategy.h"
#include "av1/encoder/ratectrl.h"
#include "av1/encoder/tpl_model.h"

namespace aom {

// Keeps a value that is about to be used as a denominator away from zero.
// Negative inputs are clamped to at most -1e-7 and everything else (including
// 0.0 and -0.0) to at least +1e-7, so the sign of a non-zero divisor is kept.
static double ModifyDivisor(double divisor) {
  constexpr double kEpsilon = 0.0000001;
  if (divisor < 0) return std::min(divisor, -kEpsilon);
  return std::max(divisor, kEpsilon);
}

// Returns a value-initialized GopFrame marked as invalid, with its coding and
// display-order indices set to the sentinel -1.
GopFrame GopFrameInvalid() {
  GopFrame invalid_frame = {};
  invalid_frame.order_idx = -1;
  invalid_frame.coding_idx = -1;
  invalid_frame.is_valid = false;
  return invalid_frame;
}

// Records gop_frame_type in gop_frame->update_type and fills in the frame-kind
// flags (key / ARF / shown / golden) plus the encode reference mode implied by
// that type. Only the enumerators listed in the switch write those fields.
// For kIntermediateArf the golden flag is derived from gop_frame->layer_depth,
// so layer_depth must already be set by the caller.
void SetGopFrameByType(GopFrameType gop_frame_type, GopFrame *gop_frame) {
  gop_frame->update_type = gop_frame_type;

  // Writes the four frame-kind flags and the reference mode in one step.
  const auto assign = [gop_frame](int key, int arf, int show, int golden,
                                  EncodeRefMode ref_mode) {
    gop_frame->is_key_frame = key;
    gop_frame->is_arf_frame = arf;
    gop_frame->is_show_frame = show;
    gop_frame->is_golden_frame = golden;
    gop_frame->encode_ref_mode = ref_mode;
  };

  switch (gop_frame_type) {
    case GopFrameType::kRegularKey:
      assign(1, 0, 1, 1, EncodeRefMode::kRegular);
      break;
    case GopFrameType::kRegularGolden:
      assign(0, 0, 1, 1, EncodeRefMode::kRegular);
      break;
    case GopFrameType::kRegularArf:
      assign(0, 1, 0, 1, EncodeRefMode::kRegular);
      break;
    case GopFrameType::kIntermediateArf:
      // Only shallow intermediate ARFs are treated as golden.
      assign(0, 1, 0, gop_frame->layer_depth <= 2 ? 1 : 0,
             EncodeRefMode::kRegular);
      break;
    case GopFrameType::kRegularLeaf:
      assign(0, 0, 1, 0, EncodeRefMode::kRegular);
      break;
    case GopFrameType::kIntermediateOverlay:
      assign(0, 0, 1, 0, EncodeRefMode::kShowExisting);
      break;
    case GopFrameType::kOverlay:
      assign(0, 0, 1, 0, EncodeRefMode::kOverlay);
      break;
  }
}

// Creates a valid GopFrame for the given GOP-local coding/display-order
// indices. The global indices are the local ones shifted by the supplied
// offsets, the layer depth is depth + kLayerDepthOffset, the colocated/update
// reference slots start at -1, and the type-dependent flags are filled in by
// SetGopFrameByType() (after layer_depth is set, which that function reads).
GopFrame GopFrameBasic(int global_coding_idx_offset,
                       int global_order_idx_offset, int coding_idx,
                       int order_idx, int depth, int display_idx,
                       GopFrameType gop_frame_type) {
  GopFrame frame = {};
  frame.is_valid = true;

  // Indices local to this GOP.
  frame.coding_idx = coding_idx;
  frame.order_idx = order_idx;
  frame.display_idx = display_idx;

  // The same indices expressed across the whole sequence.
  frame.global_coding_idx = coding_idx + global_coding_idx_offset;
  frame.global_order_idx = order_idx + global_order_idx_offset;

  frame.layer_depth = kLayerDepthOffset + depth;

  // No reference slot has been assigned yet.
  frame.colocated_ref_idx = -1;
  frame.update_ref_idx = -1;

  SetGopFrameByType(gop_frame_type, &frame);
  return frame;
}

// Appends to gop_struct the frames whose display-order indices lie in
// [order_start, order_end).
//
// While depth < max_depth and the range holds at least kMinIntervalToAddArf
// frames, the range is split at its midpoint. The midpoint is first coded as
// an intermediate ARF, then the left half is built recursively, then the ARF
// is shown via an intermediate overlay, and finally the right half is built
// recursively.
//
// Otherwise every frame of the range is added as a regular leaf frame. Each
// appended frame is first passed to ref_frame_manager->UpdateRefFrameTable().
void ConstructGopMultiLayer(GopStruct *gop_struct,
                            RefFrameManager *ref_frame_manager, int max_depth,
                            int depth, int order_start, int order_end) {
  const int global_coding_idx_offset = gop_struct->global_coding_idx_offset;
  const int global_order_idx_offset = gop_struct->global_order_idx_offset;

  // Builds the next frame (its coding index is the current list length),
  // updates the reference table with it, then appends it to the GOP.
  const auto append_frame = [gop_struct, ref_frame_manager,
                             global_coding_idx_offset, global_order_idx_offset](
                                int order_idx, int frame_depth,
                                GopFrameType type) {
    GopFrame frame = GopFrameBasic(
        global_coding_idx_offset, global_order_idx_offset,
        static_cast<int>(gop_struct->gop_frame_list.size()), order_idx,
        frame_depth, gop_struct->display_tracker, type);
    ref_frame_manager->UpdateRefFrameTable(&frame);
    gop_struct->gop_frame_list.push_back(frame);
  };

  const int num_frames = order_end - order_start;
  const bool add_intermediate_arf =
      depth < max_depth && num_frames >= kMinIntervalToAddArf;

  if (!add_intermediate_arf) {
    // Too deep or too short: emit every frame as a regular leaf.
    for (int order_idx = order_start; order_idx < order_end; ++order_idx) {
      append_frame(order_idx, max_depth, GopFrameType::kRegularLeaf);
      ++gop_struct->display_tracker;
    }
    return;
  }

  const int order_mid = (order_start + order_end) / 2;
  append_frame(order_mid, depth, GopFrameType::kIntermediateArf);
  ConstructGopMultiLayer(gop_struct, ref_frame_manager, max_depth, depth + 1,
                         order_start, order_mid);
  // Show the already-coded intermediate ARF.
  append_frame(order_mid, max_depth, GopFrameType::kIntermediateOverlay);
  ++gop_struct->display_tracker;
  ConstructGopMultiLayer(gop_struct, ref_frame_manager, max_depth, depth + 1,
                         order_mid + 1, order_end);
}

// Builds the GopStruct for a GOP of show_frame_count shown frames. Frame
// indices are offset by global_coding_idx_offset / global_order_idx_offset.
//
// When has_key_frame is set, the reference manager is reset and display-order
// frame 0 is coded as a key frame.
//
// If show_frame_count exceeds kMinIntervalToAddArf, the last frame is coded as
// an ARF, the frames in between are laid out by ConstructGopMultiLayer(), and
// the ARF is finally shown through an overlay. Otherwise the remaining frames
// are coded as regular leaf frames (IPPP).
GopStruct ConstructGop(RefFrameManager *ref_frame_manager, int show_frame_count,
                       bool has_key_frame, int global_coding_idx_offset,
                       int global_order_idx_offset) {
  GopStruct gop_struct;
  gop_struct.show_frame_count = show_frame_count;
  gop_struct.global_coding_idx_offset = global_coding_idx_offset;
  gop_struct.global_order_idx_offset = global_order_idx_offset;
  gop_struct.display_tracker = 0;
  gop_struct.base_q_ratio = 1.0;

  // Builds the next frame (coding index = current list length), updates the
  // reference table with it, then appends it to the GOP.
  const auto append_frame = [&gop_struct, ref_frame_manager,
                             global_coding_idx_offset, global_order_idx_offset](
                                int order_idx, int depth, GopFrameType type) {
    GopFrame frame = GopFrameBasic(
        global_coding_idx_offset, global_order_idx_offset,
        static_cast<int>(gop_struct.gop_frame_list.size()), order_idx, depth,
        gop_struct.display_tracker, type);
    ref_frame_manager->UpdateRefFrameTable(&frame);
    gop_struct.gop_frame_list.push_back(frame);
  };

  // Display-order index of the last shown frame in this GOP.
  const int order_end = show_frame_count - 1;
  int order_start = 0;

  if (has_key_frame) {
    constexpr int kKeyFrameDepth = -1;
    ref_frame_manager->Reset();
    append_frame(order_start, kKeyFrameDepth, GopFrameType::kRegularKey);
    ++order_start;
    ++gop_struct.display_tracker;
  }

  // TODO(jingning): Re-enable the use of pyramid coding structure.
  const bool has_arf_frame = show_frame_count > kMinIntervalToAddArf;
  constexpr int kArfDepth = 0;

  if (!has_arf_frame) {
    // IPPP: every remaining frame is a regular leaf frame.
    for (int order_idx = order_start; order_idx <= order_end; ++order_idx) {
      append_frame(order_idx, kArfDepth + 1, GopFrameType::kRegularLeaf);
      ++gop_struct.display_tracker;
    }
    return gop_struct;
  }

  // Multi-layer pyramid: code the ARF first, fill in the frames before it,
  // then show the ARF through an overlay.
  append_frame(order_end, kArfDepth, GopFrameType::kRegularArf);
  ConstructGopMultiLayer(&gop_struct, ref_frame_manager,
                         ref_frame_manager->MaxRefFrame() - 1, kArfDepth + 1,
                         order_start, order_end);
  append_frame(order_end, ref_frame_manager->MaxRefFrame() - 1,
               GopFrameType::kOverlay);
  ++gop_struct.display_tracker;
  return gop_struct;
}

// Validates rc_param and stores it in rc_param_ only if every check passes.
// The checks run in this order, and the first failure is returned as
// AOM_CODEC_INVALID_PARAM with a message naming the offending field(s):
//   - max_gop_show_frame_count >= max(4, min_gop_show_frame_count)
//   - ref_frame_table_size in [1, 8]
//   - max_ref_frames in [1, 7]
//   - base_q_index in [0, 255]
//   - frame_width and frame_height in [16, 16384]
Status AV1RateControlQMode::SetRcParam(const RateControlParam &rc_param) {
  std::ostringstream error_message;
  // Packages whatever has been streamed into error_message as a failure.
  const auto invalid_param = [&error_message]() -> Status {
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  };

  const int smallest_allowed_max_gop =
      std::max(4, rc_param.min_gop_show_frame_count);
  if (rc_param.max_gop_show_frame_count < smallest_allowed_max_gop) {
    error_message << "max_gop_show_frame_count ("
                  << rc_param.max_gop_show_frame_count
                  << ") must be at least 4 and may not be less than "
                     "min_gop_show_frame_count ("
                  << rc_param.min_gop_show_frame_count << ")";
    return invalid_param();
  }

  const bool bad_table_size = rc_param.ref_frame_table_size < 1 ||
                              rc_param.ref_frame_table_size > 8;
  if (bad_table_size) {
    error_message << "ref_frame_table_size (" << rc_param.ref_frame_table_size
                  << ") must be in the range [1, 8].";
    return invalid_param();
  }

  const bool bad_max_refs =
      rc_param.max_ref_frames < 1 || rc_param.max_ref_frames > 7;
  if (bad_max_refs) {
    error_message << "max_ref_frames (" << rc_param.max_ref_frames
                  << ") must be in the range [1, 7].";
    return invalid_param();
  }

  const bool bad_base_q =
      rc_param.base_q_index < 0 || rc_param.base_q_index > 255;
  if (bad_base_q) {
    error_message << "base_q_index (" << rc_param.base_q_index
                  << ") must be in the range [0, 255].";
    return invalid_param();
  }

  const bool bad_width =
      rc_param.frame_width < 16 || rc_param.frame_width > 16384;
  const bool bad_height =
      rc_param.frame_height < 16 || rc_param.frame_height > 16384;
  if (bad_width || bad_height) {
    error_message << "frame_width (" << rc_param.frame_width
                  << ") and frame_height (" << rc_param.frame_height
                  << ") must be in the range [16, 16384].";
    return invalid_param();
  }

  rc_param_ = rc_param;
  return { AOM_CODEC_OK, "" };
}

// Threshold on the share of blocks using the lagging second reference frame.
// High second-reference usage may point to a transient event such as a flash
// or occlusion rather than a real scene cut.
//
// The threshold ramps linearly from 0.085 (no frames yet in this key-frame
// group) up to 0.085 + 0.035, reached at 31 frames. It stays there for 32 or
// more frames.
static double GetSecondRefUsageThreshold(int frame_count_so_far) {
  constexpr int kAdaptUpto = 32;
  constexpr double kMinThresh = 0.085;
  constexpr double kMaxDelta = 0.035;
  const double progress =
      frame_count_so_far >= kAdaptUpto
          ? 1.0
          : static_cast<double>(frame_count_so_far) / (kAdaptUpto - 1);
  return kMinThresh + progress * kMaxDelta;
}

// Slide-show transition detection.
//
// Returns true when this frame's coded error spikes sharply (more than 5x)
// compared with both the last and the next frame, while its intra error stays
// close to its coded error (intra < 1.5x coded). The second condition helps
// reject harmful false positives.
//
// This can catch key frames in slide shows even when the slides have
// different sizes. It does not help with fades or other multi-frame
// transitions.
static bool DetectSlideTransition(const FIRSTPASS_STATS &this_frame,
                                  const FIRSTPASS_STATS &last_frame,
                                  const FIRSTPASS_STATS &next_frame) {
  // Intra / inter error ratio must be very low.
  constexpr double kVeryLowII = 1.5;
  // A clean slide transition shows a sharp single-frame spike in error.
  constexpr double kErrorSpike = 5.0;

  // TODO(angiebird): Understand the meaning of these conditions.
  const bool intra_close_to_inter =
      this_frame.intra_error < this_frame.coded_error * kVeryLowII;
  const bool spikes_over_last =
      this_frame.coded_error > last_frame.coded_error * kErrorSpike;
  const bool spikes_over_next =
      this_frame.coded_error > next_frame.coded_error * kErrorSpike;
  return intra_close_to_inter && spikes_over_last && spikes_over_next;
}

// Returns true when the stats around a key-frame candidate show a significant
// intra/inter error change, meaning the candidate deserves further key-frame
// testing. All of the following must hold:
//   - this frame's intra share (1 - pcnt_inter) exceeds both kMinIntraLevel
//     and kIntraVsInterRatio times its non-neutral inter share;
//   - this frame's intra/coded error ratio is below
//     kThisIntraCodedErrorRatioMax;
//   - at least one of the following holds:
//       * the relative change in coded error from the last frame exceeds
//         kErrorChangeThreshold,
//       * the relative change in intra error from the last frame exceeds
//         kErrorChangeThreshold,
//       * the next frame's intra/coded error ratio exceeds
//         kNextIntraCodedErrorRatioMin.
static bool DetectIntraInterErrorChange(const FIRSTPASS_STATS &this_stats,
                                        const FIRSTPASS_STATS &last_stats,
                                        const FIRSTPASS_STATS &next_stats) {
  // Minimum % intra coding observed in first pass (1.0 = 100%).
  constexpr double kMinIntraLevel = 0.25;
  // Minimum ratio between intra and inter coding after discounting neutral
  // blocks. Discounting them helps catch scene cuts in very flat content and
  // in letterboxed clips with image padding.
  constexpr double kIntraVsInterRatio = 2.0;
  // Real scene cuts almost always show a sharp change in the error scores.
  constexpr double kErrorChangeThreshold = 0.4;
  // Upper bound on this frame's intra vs. best inter error ratio.
  constexpr double kThisIntraCodedErrorRatioMax = 1.9;
  // Real scene cuts should improve the intra/inter ratio in the next frame.
  constexpr double kNextIntraCodedErrorRatioMin = 3.5;

  const double pcnt_intra = 1.0 - this_stats.pcnt_inter;
  const double non_neutral_pcnt_inter =
      this_stats.pcnt_inter - this_stats.pcnt_neutral;
  const double pcnt_intra_min =
      std::max(kMinIntraLevel, kIntraVsInterRatio * non_neutral_pcnt_inter);
  if (!(pcnt_intra > pcnt_intra_min)) return false;

  const double this_intra_coded_error_ratio =
      this_stats.intra_error / ModifyDivisor(this_stats.coded_error);
  if (!(this_intra_coded_error_ratio < kThisIntraCodedErrorRatioMax)) {
    return false;
  }

  const double last_this_coded_error_ratio =
      fabs(last_stats.coded_error - this_stats.coded_error) /
      ModifyDivisor(this_stats.coded_error);
  // Relative change in intra error between the LAST frame and this frame;
  // the next frame is not involved in this measure.
  const double last_this_intra_error_ratio =
      fabs(last_stats.intra_error - this_stats.intra_error) /
      ModifyDivisor(this_stats.intra_error);
  const double next_intra_coded_error_ratio =
      next_stats.intra_error / ModifyDivisor(next_stats.coded_error);

  return last_this_coded_error_ratio > kErrorChangeThreshold ||
         last_this_intra_error_ratio > kErrorChangeThreshold ||
         next_intra_coded_error_ratio > kNextIntraCodedErrorRatioMin;
}

// Checks whether the frame at candidate_key_idx in first_pass_info.stats_list
// should become a key frame (scene cut).
// This is a rewrite of test_candidate_kf().
//
// Returns false when any of the following holds:
//   - the candidate has no previous or no next stats entry;
//   - fewer than 3 frames have passed since the previous key frame;
//   - this frame or the next one uses the second reference at or above
//     GetSecondRefUsageThreshold().
//
// Otherwise, a scene-cut signal must be present: very low inter usage, a
// slide transition, or an intra/inter error change. If it is, up to
// kSceneCutKeyTestIntervalMax following frames are scored by how well they
// would be predicted from the candidate.
//
// The candidate is accepted when the accumulated boost score exceeds 30.0 and
// at least three following frames were examined without any of the first
// three triggering a breakout condition.
static bool TestCandidateKey(const FirstpassInfo &first_pass_info,
                             int candidate_key_idx, int frames_since_prev_key) {
  const auto &stats_list = first_pass_info.stats_list;
  const int stats_count = static_cast<int>(stats_list.size());
  // The candidate needs both a previous and a next stats entry.
  if (candidate_key_idx + 1 >= stats_count || candidate_key_idx - 1 < 0) {
    return false;
  }
  const auto &last_stats = stats_list[candidate_key_idx - 1];
  const auto &this_stats = stats_list[candidate_key_idx];
  const auto &next_stats = stats_list[candidate_key_idx + 1];

  // Never place key frames fewer than 3 frames apart.
  if (frames_since_prev_key < 3) return false;
  // High second-reference usage in this or the next frame is treated as a
  // transient event (e.g. flash/occlusion) rather than a scene cut.
  const double second_ref_usage_threshold =
      GetSecondRefUsageThreshold(frames_since_prev_key);
  if (this_stats.pcnt_second_ref >= second_ref_usage_threshold) return false;
  if (next_stats.pcnt_second_ref >= second_ref_usage_threshold) return false;

  // Hard threshold where the first pass chooses intra for almost all blocks.
  // In such a case even if the frame is not a scene cut coding a key frame
  // may be a good option.
  constexpr double kVeryLowInterThreshold = 0.05;
  if (this_stats.pcnt_inter < kVeryLowInterThreshold ||
      DetectSlideTransition(this_stats, last_stats, next_stats) ||
      DetectIntraInterErrorChange(this_stats, last_stats, next_stats)) {
    double boost_score = 0.0;
    double decay_accumulator = 1.0;

    // We do "-1" because the candidate key is not counted.
    int stats_after_this_stats = stats_count - candidate_key_idx - 1;

    // Number of frames required to test for scene cut detection
    constexpr int kSceneCutKeyTestIntervalMax = 16;

    // Make sure we have enough stats after the candidate key.
    const int frames_to_test_after_candidate_key =
        std::min(kSceneCutKeyTestIntervalMax, stats_after_this_stats);

    // Examine how well the key frame predicts subsequent frames.
    // |i| is declared outside the loop because its final value (how far the
    // scan got before a breakout) is checked after the loop.
    int i;
    for (i = 1; i <= frames_to_test_after_candidate_key; ++i) {
      // Get the next frame details
      const auto &stats = stats_list[candidate_key_idx + i];

      // Cumulative effect of decay in prediction quality.
      if (stats.pcnt_inter > 0.85) {
        decay_accumulator *= stats.pcnt_inter;
      } else {
        decay_accumulator *= (0.85 + stats.pcnt_inter) / 2.0;
      }

      // Scaled intra/coded error ratio of this following frame, capped at
      // 128.0.
      constexpr double kBoostFactor = 12.5;
      double next_iiratio =
          (kBoostFactor * stats.intra_error / ModifyDivisor(stats.coded_error));
      next_iiratio = std::min(next_iiratio, 128.0);
      double boost_score_increment = decay_accumulator * next_iiratio;

      // Keep a running total.
      boost_score += boost_score_increment;

      // Test various breakout clauses.
      // TODO(any): Test of intra error should be normalized to an MB.
      // TODO(angiebird): Investigate the following questions.
      // Question 1: next_iiratio (intra_error / coded_error) * kBoostFactor
      // If intra_error / coded_error >= 1 and kBoostFactor = 12.5, then
      // (intra_error / coded_error) * kBoostFactor is always greater
      // than 1.5. Is "next_iiratio < 1.5" always false?
      // Question 2: Similar to question 1, is "next_iiratio < 3.0" always
      // false as well?
      // Question 3: Why do we need to divide 200 with num_mbs_16x16?
      if ((stats.pcnt_inter < 0.05) || (next_iiratio < 1.5) ||
          (((stats.pcnt_inter - stats.pcnt_neutral) < 0.20) &&
           (next_iiratio < 3.0)) ||
          (boost_score_increment < 3.0) ||
          (stats.intra_error <
           (200.0 / static_cast<double>(first_pass_info.num_mbs_16x16)))) {
        break;
      }
    }

    // If there is tolerable prediction for at least the next 3 frames then
    // break out else discard this potential key frame and move on
    const int count_for_tolerable_prediction = 3;
    if (boost_score > 30.0 && (i > count_for_tolerable_prediction)) {
      return true;
    }
  }
  return false;
}

// Returns the display-order indices of key frames detected in first_pass_info.
// Index 0 is always a key frame. Every later frame is tested with
// TestCandidateKey(), using its distance from the most recent key frame.
std::vector<int> GetKeyFrameList(const FirstpassInfo &first_pass_info) {
  const int num_frames = static_cast<int>(first_pass_info.stats_list.size());
  std::vector<int> key_frame_list = { 0 };
  for (int frame_idx = 1; frame_idx < num_frames; ++frame_idx) {
    const int distance_to_prev_key = frame_idx - key_frame_list.back();
    if (TestCandidateKey(first_pass_info, frame_idx, distance_to_prev_key)) {
      key_frame_list.push_back(frame_idx);
    }
  }
  return key_frame_list;
}

// Resets the GF_GROUP_STATS fields listed below to their starting values.
// Error/skip/motion sums, averages and the stdev counter start at zero. The
// decay and zero-motion accumulators and both loop decay rates start at 1.0.
static void InitGFStats(GF_GROUP_STATS *gf_stats) {
  GF_GROUP_STATS &stats = *gf_stats;

  // Fields that start at zero.
  stats.gf_group_err = 0.0;
  stats.gf_group_raw_error = 0.0;
  stats.gf_group_skip_pct = 0.0;
  stats.gf_group_inactive_zone_rows = 0.0;
  stats.mv_ratio_accumulator = 0.0;
  stats.this_frame_mv_in_out = 0.0;
  stats.mv_in_out_accumulator = 0.0;
  stats.abs_mv_in_out_accumulator = 0.0;
  stats.avg_sr_coded_error = 0.0;
  stats.avg_pcnt_second_ref = 0.0;
  stats.avg_new_mv_count = 0.0;
  stats.avg_wavelet_energy = 0.0;
  stats.avg_raw_err_stdev = 0.0;
  stats.non_zero_stdev_count = 0;

  // Fields that start at one.
  stats.decay_accumulator = 1.0;
  stats.zero_motion_accumulator = 1.0;
  stats.loop_decay_rate = 1.0;
  stats.last_loop_decay_rate = 1.0;
}

static int FindRegionIndex(const std::vector<REGIONS> &regions, int frame_idx) {
  for (int k = 0; k < static_cast<int>(regions.size()); k++) {
    if (regions[k].start <= frame_idx && regions[k].last >= frame_idx) {
      return k;
    }
  }
  return -1;
}

// This function detects a flash through the high relative pcnt_second_ref
// score in the frame following a flash frame. The offset passed in should
// reflect this.
static bool DetectFlash(const std::vector<FIRSTPASS_STATS> &stats_list,
                        int index) {
  int next_index = index + 1;
  if (next_index >= static_cast<int>(stats_list.size())) return false;
  const FIRSTPASS_STATS &next_frame = stats_list[next_index];

  // What we are looking for here is a situation where there is a
  // brief break in prediction (such as a flash) but subsequent frames
  // are reasonably well predicted by an earlier (pre flash) frame.
  // The recovery after a flash is indicated by a high pcnt_second_ref
  // compared to pcnt_inter.
  return next_frame.pcnt_second_ref > next_frame.pcnt_inter &&
         next_frame.pcnt_second_ref >= 0.5;
}

// Lower bound on how far (in frames) past cur_start FindBetterGopCut() starts
// scoring candidate cut positions. The effective bound is the larger of this
// value and min_gop_show_frame_count.
#define MIN_SHRINK_LEN 6

// Takes a suggested GOP interval from cur_start to cur_last, analyzes the
// first-pass stats and region stats, and returns a better GOP cut location:
// the new last frame of the GOP, in the same coordinates as cur_last.
//
// TODO(b/231517281): Simplify the indices once we have an unit test.
// We are using four indices here: order_index, cur_start, cur_last, and
// frames_since_key. Ideally, only three indices are needed:
// 1) start_index = order_index + cur_start
// 2) end_index = order_index + cur_last
// 3) key_index
//
// As used below, a GOP-relative frame j reads stats_list[order_index + j] and
// is located in regions_list via j + frames_since_key.
//
// cur_last is returned unchanged when:
//   - the interval is empty or inverted;
//   - the interval is longer than max_gop_show_frame_count;
//   - regions_list is empty.
int FindBetterGopCut(const std::vector<FIRSTPASS_STATS> &stats_list,
                     const std::vector<REGIONS> &regions_list,
                     int min_gop_show_frame_count, int max_gop_show_frame_count,
                     int order_index, int cur_start, int cur_last,
                     int frames_since_key) {
  // Only try shrinking if the interval is not longer than
  // max_gop_show_frame_count.
  if (cur_last - cur_start > max_gop_show_frame_count ||
      cur_start >= cur_last) {
    return cur_last;
  }
  int num_regions = static_cast<int>(regions_list.size());
  // The analysis below would otherwise index into an empty regions_list
  // (e.g. regions_list[k_start] and regions_list[0]).
  if (num_regions == 0) return cur_last;
  int num_stats = static_cast<int>(stats_list.size());
  const int min_shrink_int = std::max(MIN_SHRINK_LEN, min_gop_show_frame_count);

  // Find the region indices of where the first and last frame belong
  // (-1 when a frame is not covered by any region).
  int k_start = FindRegionIndex(regions_list, cur_start + frames_since_key);
  int k_last = FindRegionIndex(regions_list, cur_last + frames_since_key);
  if (cur_start + frames_since_key == 0) k_start = 0;

  int scenecut_idx = -1;
  // See if we have a scenecut in between.
  for (int r = k_start + 1; r <= k_last; r++) {
    if (regions_list[r].type == SCENECUT_REGION &&
        regions_list[r].last - frames_since_key - cur_start >
            min_gop_show_frame_count) {
      scenecut_idx = r;
      break;
    }
  }

  // If the found scenecut is very close to the end, ignore it.
  if (scenecut_idx >= 0 &&
      regions_list[num_regions - 1].last - regions_list[scenecut_idx].last <
          4) {
    scenecut_idx = -1;
  }

  if (scenecut_idx != -1) {
    // If we have a scenecut, then stop at it.
    // TODO(bohanli): add logic here to stop before the scenecut and for
    // the next gop start from the scenecut with GF
    int is_minor_sc =
        (regions_list[scenecut_idx].avg_cor_coeff *
             (1 - stats_list[order_index + regions_list[scenecut_idx].start -
                             frames_since_key]
                          .noise_var /
                      regions_list[scenecut_idx].avg_intra_err) >
         0.6);
    // For a non-minor scenecut, end the GOP one frame earlier.
    cur_last =
        regions_list[scenecut_idx].last - frames_since_key - !is_minor_sc;
  } else {
    int is_last_analysed =
        (k_last == num_regions - 1) &&
        (cur_last + frames_since_key == regions_list[k_last].last);
    // k_start is -1 when the GOP's first frame is not covered by any region;
    // never index regions_list with it.
    int not_enough_regions =
        k_last - k_start <=
        1 + (k_start >= 0 && regions_list[k_start].type == SCENECUT_REGION);
    // If we are very close to the end, then do not shrink since it may
    // introduce intervals that are too short.
    if (!(is_last_analysed && not_enough_regions)) {
      const double arf_length_factor = 0.1;
      double best_score = 0;
      int best_j = -1;
      const int first_frame = regions_list[0].start - frames_since_key;
      const int last_frame =
          regions_list[num_regions - 1].last - frames_since_key;
      // Score of how much the ARF helps the whole GOP.
      double base_score = 0.0;
      // Accumulate base_score over the frames that are too close to cur_start
      // to be cut candidates.
      for (int j = cur_start + 1; j < cur_start + min_shrink_int; j++) {
        if (order_index + j >= num_stats) break;
        base_score = (base_score + 1.0) * stats_list[order_index + j].cor_coeff;
      }
      int met_blending = 0;   // Whether we have met blending areas before
      int last_blending = 0;  // Whether the previous frame is blending
      for (int j = cur_start + min_shrink_int; j <= cur_last; j++) {
        if (order_index + j >= num_stats) break;
        base_score = (base_score + 1.0) * stats_list[order_index + j].cor_coeff;
        int this_reg = FindRegionIndex(regions_list, j + frames_since_key);
        if (this_reg < 0) continue;
        // A GOP should include at most 1 blending region.
        if (regions_list[this_reg].type == BLENDING_REGION) {
          last_blending = 1;
          if (met_blending) {
            break;
          } else {
            base_score = 0;
            continue;
          }
        } else {
          if (last_blending) met_blending = 1;
          last_blending = 0;
        }

        // Add the factor of how good the neighborhood is for this
        // candidate arf.
        double this_score = arf_length_factor * base_score;
        double temp_accu_coeff = 1.0;
        // following frames
        int count_f = 0;
        for (int n = j + 1; n <= j + 3 && n <= last_frame; n++) {
          if (order_index + n >= num_stats) break;
          temp_accu_coeff *= stats_list[order_index + n].cor_coeff;
          this_score +=
              temp_accu_coeff *
              sqrt(AOMMAX(
                  0.5, 1 - stats_list[order_index + n].noise_var /
                               AOMMAX(stats_list[order_index + n].intra_error,
                                      0.001)));
          count_f++;
        }
        // preceding frames
        temp_accu_coeff = 1.0;
        for (int n = j; n > j - 3 * 2 + count_f && n > first_frame; n--) {
          if (order_index + n < 0) break;
          temp_accu_coeff *= stats_list[order_index + n].cor_coeff;
          this_score +=
              temp_accu_coeff *
              sqrt(AOMMAX(
                  0.5, 1 - stats_list[order_index + n].noise_var /
                               AOMMAX(stats_list[order_index + n].intra_error,
                                      0.001)));
        }

        if (this_score > best_score) {
          best_score = this_score;
          best_j = j;
        }
      }

      // For blending areas, move one more frame in case we missed the
      // first blending frame.
      int best_reg = FindRegionIndex(regions_list, best_j + frames_since_key);
      if (best_reg < num_regions - 1 && best_reg > 0) {
        if (regions_list[best_reg - 1].type == BLENDING_REGION &&
            regions_list[best_reg + 1].type == BLENDING_REGION) {
          if (best_j + frames_since_key == regions_list[best_reg].start &&
              best_j + frames_since_key < regions_list[best_reg].last) {
            best_j += 1;
          } else if (best_j + frames_since_key == regions_list[best_reg].last &&
                     best_j + frames_since_key > regions_list[best_reg].start) {
            best_j -= 1;
          }
        }
      }

      if (cur_last - best_j < 2) best_j = cur_last;
      if (best_j > 0 && best_score > 0.1) cur_last = best_j;
      // if cannot find anything, just cut at the original place.
    }
  }

  return cur_last;
}

// Tests for a complex transition followed by a static section, e.g. a fade
// between slides in a slide show. This helps place key frames and golden
// frames better.
//
// Returns true only when all of the following hold:
//   - frame_interval > min_gop_show_frame_count;
//   - loop_decay_rate >= 0.999 and last_decay_rate < 0.9;
//   - at least still_interval stats remain from next_stats_index;
//   - every one of those still_interval frames has
//     pcnt_inter - pcnt_motion >= 0.999.
static bool DetectTransitionToStill(
    const std::vector<FIRSTPASS_STATS> &stats_list, int next_stats_index,
    int min_gop_show_frame_count, int frame_interval, int still_interval,
    double loop_decay_rate, double last_decay_rate) {
  // Break clause to detect very still sections after motion, such as a static
  // image after a fade or other transition instead of a clean scene cut.
  const bool decay_suggests_still = frame_interval > min_gop_show_frame_count &&
                                    loop_decay_rate >= 0.999 &&
                                    last_decay_rate < 0.9;
  if (!decay_suggests_still) return false;

  const int stats_left =
      static_cast<int>(stats_list.size()) - next_stats_index;
  if (stats_left < still_interval) return false;

  // Look ahead to see whether the static condition persists.
  int still_frames = 0;
  while (still_frames < still_interval) {
    const FIRSTPASS_STATS &stats = stats_list[next_stats_index + still_frames];
    if (stats.pcnt_inter - stats.pcnt_motion < 0.999) break;
    ++still_frames;
  }
  // Signal a transition to still only if every look-ahead frame was static.
  return still_frames == still_interval;
}

// Decides whether the GOP starting at stats index start_idx should be cut at
// candidate_cut_idx. Returns 1 to cut and 0 otherwise. Nothing is cut when
// flash_detected is set.
//
// Otherwise the GOP is cut when either of these holds:
//   - DetectTransitionToStill() reports a motion-to-still transition;
//   - the candidate GOP satisfies all of the following:
//       * it is at least min_gop_show_frame_count long,
//       * it ends at least min_gop_show_frame_count frames before
//         next_key_idx,
//       * its size is odd,
//       * accumulated motion is high (mv ratio above (width + height) / 4, or
//         absolute in/out zoom above 4.4).
//
// max_gop_show_frame_count is currently unused.
static int DetectGopCut(const std::vector<FIRSTPASS_STATS> &stats_list,
                        int start_idx, int candidate_cut_idx, int next_key_idx,
                        int flash_detected, int min_gop_show_frame_count,
                        int max_gop_show_frame_count, int frame_width,
                        int frame_height, const GF_GROUP_STATS &gf_stats) {
  (void)max_gop_show_frame_count;
  if (flash_detected) return 0;

  const int candidate_gop_size = candidate_cut_idx - start_idx;

  // Break clause to detect very still sections after motion, e.g. a static
  // image after a fade or other transition.
  // NOTE(review): the still look-ahead starts at start_idx (the GOP start)
  // rather than at candidate_cut_idx; confirm this is intended.
  constexpr int kStillInterval = 5;
  if (DetectTransitionToStill(stats_list, start_idx, min_gop_show_frame_count,
                              candidate_gop_size, kStillInterval,
                              gf_stats.loop_decay_rate,
                              gf_stats.last_loop_decay_rate)) {
    return 1;
  }

  constexpr double kArfAbsZoomThresh = 4.4;
  // The motion breakout threshold depends on image size.
  const double mv_ratio_accumulator_thresh =
      (frame_height + frame_width) / 4.0;

  const bool long_enough = candidate_gop_size >= min_gop_show_frame_count;
  // If possible, don't break very close to a key frame.
  const bool far_from_key =
      next_key_idx - candidate_cut_idx >= min_gop_show_frame_count;
  const bool odd_size = (candidate_gop_size & 0x01) != 0;
  const bool high_motion =
      gf_stats.mv_ratio_accumulator > mv_ratio_accumulator_thresh ||
      gf_stats.abs_mv_in_out_accumulator > kArfAbsZoomThresh;
  if (long_enough && far_from_key && odd_size && high_motion) return 1;

  // TODO(b/231489624): Check if we need this part.
  // If almost totally static, we will not use the max GF length later,
  // so we can continue for more frames.
  // if ((candidate_gop_size >= active_max_gf_interval + 1) &&
  //     !is_almost_static(gf_stats->zero_motion_accumulator,
  //                       twopass->kf_zeromotion_pct, cpi->ppi->lap_enabled)) {
  //   return 0;
  // }
  return 0;
}

/*!\brief Determine the length of future GF groups.
 *
 * \ingroup gf_group_algo
 * This function decides the gf group length of future frames in batch.
 * Frames are scanned starting at stats_list[order_index]; the loop counter
 * and all cut positions are relative to order_index. Scanning stops at the
 * next key frame (frames_to_key) or at the end of stats_list, whichever comes
 * first.
 *
 * \param[in]    rc_param         Rate control parameters
 * \param[in]    stats_list       List of first pass stats
 * \param[in]    regions_list     List of regions from av1_identify_regions
 * \param[in]    order_index      Index of current frame in stats_list
 * \param[in]    frames_since_key Number of frames since the last key frame
 * \param[in]    frames_to_key    Number of frames to the next key frame
 *
 * \return Returns a vector of decided GF group lengths, i.e. the differences
 *         between consecutive cut positions (the first GOP is measured from
 *         position -1, so its length includes frame 0).
 */
static std::vector<int> PartitionGopIntervals(
    const RateControlParam &rc_param,
    const std::vector<FIRSTPASS_STATS> &stats_list,
    const std::vector<REGIONS> &regions_list, int order_index,
    int frames_since_key, int frames_to_key) {
  // Candidate cut position, relative to order_index.
  int i = 0;
  // Start of the GOP currently being grown, relative to order_index. After a
  // cut it is set to the previous GOP's last frame (plus one when the frame
  // that follows falls inside a scenecut region).
  int cur_start = 0;
  // Each element is the last frame of the previous GOP. If there are n GOPs,
  // you need n + 1 cuts to find the durations. So cut_pos starts out with -1,
  // which is the last frame of the previous GOP.
  std::vector<int> cut_pos(1, -1);
  // 0: no cut at i. 1: cut so that frame i - 1 ends the GOP.
  // 2: cut so that frame i itself ends the GOP (see original_last below).
  int cut_here = 0;
  GF_GROUP_STATS gf_stats;
  InitGFStats(&gf_stats);
  int num_stats = static_cast<int>(stats_list.size());

  while (i + order_index < num_stats) {
    // reaches next key frame, break here
    if (i >= frames_to_key - 1) {
      cut_here = 2;
    } else if (i - cur_start >= rc_param.max_gop_show_frame_count) {
      // reached maximum len, but nothing special yet (almost static)
      // let's look at the next interval
      cut_here = 2;
    } else {
      // Test for the case where there is a brief flash but the prediction
      // quality back to an earlier frame is then restored.
      const int gop_start_idx = cur_start + order_index;
      const int candidate_gop_cut_idx = i + order_index;
      const int next_key_idx = frames_to_key + order_index;
      const bool flash_detected =
          DetectFlash(stats_list, candidate_gop_cut_idx);

      // TODO(bohanli): remove redundant accumulations here, or unify
      // this and the ones in define_gf_group
      const FIRSTPASS_STATS *stats = &stats_list[candidate_gop_cut_idx];
      av1_accumulate_next_frame_stats(stats, flash_detected, frames_since_key,
                                      i, &gf_stats, rc_param.frame_width,
                                      rc_param.frame_height);

      // TODO(angiebird): Can we simplify this part? Looks like we are going to
      // change the gop cut index with FindBetterGopCut() anyway.
      cut_here = DetectGopCut(
          stats_list, gop_start_idx, candidate_gop_cut_idx, next_key_idx,
          flash_detected, rc_param.min_gop_show_frame_count,
          rc_param.max_gop_show_frame_count, rc_param.frame_width,
          rc_param.frame_height, gf_stats);
    }

    if (!cut_here) {
      ++i;
      continue;
    }

    // the current last frame in the gf group
    int original_last = cut_here > 1 ? i : i - 1;
    int cur_last = FindBetterGopCut(
        stats_list, regions_list, rc_param.min_gop_show_frame_count,
        rc_param.max_gop_show_frame_count, order_index, cur_start,
        original_last, frames_since_key);
    // FindBetterGopCut() may move the cut away from original_last; record the
    // refined position.
    cut_pos.push_back(cur_last);

    // reset pointers to the shrunken location
    cur_start = cur_last;
    int cur_region_idx =
        FindRegionIndex(regions_list, cur_start + 1 + frames_since_key);
    if (cur_region_idx >= 0)
      if (regions_list[cur_region_idx].type == SCENECUT_REGION) cur_start++;

    // reset accumulators
    InitGFStats(&gf_stats);
    i = cur_last + 1;

    // Stop once a forced cut (key frame or max length) has consumed every
    // frame up to the next key frame.
    if (cut_here == 2 && i >= frames_to_key) break;
  }

  std::vector<int> gf_intervals;
  // save intervals
  for (size_t n = 1; n < cut_pos.size(); n++) {
    gf_intervals.push_back(cut_pos[n] - cut_pos[n - 1]);
  }

  return gf_intervals;
}

// Runs the first pass analysis passes over a private copy of the stats.
// The argument is taken by value, so the caller's FirstpassInfo is untouched;
// the copy is processed by av1_mark_flashes(), av1_estimate_noise() and
// av1_estimate_coeff(), in that order, and then returned.
FirstpassInfo AnalyzeFpStats(FirstpassInfo firstpass_info) {
  auto &stats = firstpass_info.stats_list;
  FIRSTPASS_STATS *const first = stats.data();
  FIRSTPASS_STATS *const last = first + static_cast<int>(stats.size());
  av1_mark_flashes(first, last);
  av1_estimate_noise(first, last);
  av1_estimate_coeff(first, last);
  return firstpass_info;
}

// Splits the first pass stats into key frame intervals, partitions each
// interval into GOPs with PartitionGopIntervals(), and builds a GopStruct for
// every GOP with ConstructGop().
//
// Each GOP's base_q_ratio is exp(2 * (global_avg - gop_avg)), where both
// averages are of log(1 + min(coded_error, sr_coded_error)): one over the
// whole sequence, one over the GOP's shown frames. The ratio is exactly 1.0
// when the two averages differ by less than 0.001.
//
// Returns AOM_CODEC_INVALID_PARAM if |firstpass_info| has no stats.
StatusOr<GopStructList> AV1RateControlQMode::DetermineGopInfo(
    const FirstpassInfo &firstpass_info) {
  const int stats_size = static_cast<int>(firstpass_info.stats_list.size());
  if (stats_size <= 0) {
    Status status;
    status.code = AOM_CODEC_INVALID_PARAM;
    status.message = "The firstpass info length is insufficient.";
    return status;
  }
  GopStructList gop_list;
  RefFrameManager ref_frame_manager(rc_param_.ref_frame_table_size,
                                    rc_param_.max_ref_frames);

  // AnalyzeFpStats() takes its argument by value, so this call makes the copy
  // that gets analyzed (std::move on a const reference would copy anyway).
  const FirstpassInfo analyzed_fp_info = AnalyzeFpStats(firstpass_info);

  int global_coding_idx_offset = 0;
  int global_order_idx_offset = 0;
  std::vector<int> key_frame_list = GetKeyFrameList(analyzed_fp_info);
  key_frame_list.push_back(stats_size);  // a sentinel value
  for (size_t ki = 0; ki + 1 < key_frame_list.size(); ++ki) {
    int frames_to_key = key_frame_list[ki + 1] - key_frame_list[ki];
    int key_order_index = key_frame_list[ki];  // The key frame's display order

    // regions_list only has room for MAX_FIRSTPASS_ANALYSIS_FRAMES regions, so
    // av1_identify_regions() must not be asked to analyze more frames than
    // that; a longer key frame interval could otherwise write past the end of
    // the buffer. Frames beyond the analyzed span presumably just have no
    // region (FindRegionIndex() < 0, which PartitionGopIntervals() handles).
    const int frames_to_analyze = std::min(
        frames_to_key, static_cast<int>(MAX_FIRSTPASS_ANALYSIS_FRAMES));
    std::vector<REGIONS> regions_list(MAX_FIRSTPASS_ANALYSIS_FRAMES);
    int total_regions = 0;
    av1_identify_regions(analyzed_fp_info.stats_list.data() + key_order_index,
                         frames_to_analyze, 0, regions_list.data(),
                         &total_regions);
    regions_list.resize(total_regions);
    std::vector<int> gf_intervals = PartitionGopIntervals(
        rc_param_, analyzed_fp_info.stats_list, regions_list, key_order_index,
        /*frames_since_key=*/0, frames_to_key);
    for (size_t gi = 0; gi < gf_intervals.size(); ++gi) {
      // Only the first GOP of each key frame interval starts with the key
      // frame.
      const bool has_key_frame = gi == 0;
      const int show_frame_count = gf_intervals[gi];
      GopStruct gop =
          ConstructGop(&ref_frame_manager, show_frame_count, has_key_frame,
                       global_coding_idx_offset, global_order_idx_offset);
      assert(gop.show_frame_count == show_frame_count);
      global_coding_idx_offset += static_cast<int>(gop.gop_frame_list.size());
      global_order_idx_offset += gop.show_frame_count;
      gop_list.push_back(gop);
    }
  }

  // Determine the qp adjustment ratio for each gop.
  double global_avg_coded_error = 0.0;
  for (int i = 0; i < stats_size; ++i) {
    global_avg_coded_error +=
        log(1.0 + std::min(analyzed_fp_info.stats_list[i].coded_error,
                           analyzed_fp_info.stats_list[i].sr_coded_error));
  }
  global_avg_coded_error /= static_cast<double>(stats_size);

  for (auto &gop_struct : gop_list) {
    double gop_avg_coded_error = 0.0;
    for (int i = gop_struct.global_order_idx_offset;
         i < gop_struct.global_order_idx_offset + gop_struct.show_frame_count;
         ++i) {
      gop_avg_coded_error +=
          log(1.0 + std::min(analyzed_fp_info.stats_list[i].coded_error,
                             analyzed_fp_info.stats_list[i].sr_coded_error));
    }
    gop_avg_coded_error /=
        std::max(static_cast<double>(gop_struct.show_frame_count), 1.0);

    gop_struct.base_q_ratio =
        fabs(global_avg_coded_error - gop_avg_coded_error) < 0.001
            ? 1.0
            : exp((global_avg_coded_error - gop_avg_coded_error) * 2.0);
  }
  return gop_list;
}

// Allocates dependency stats for a frame_width x frame_height frame divided
// into square units of min_block_size pixels. The grid size is rounded up, so
// partial units on the right and bottom edges get an entry, and every unit
// starts value-initialized. alt_unit_stats receives an identically shaped
// grid only when has_alt_stats is true; otherwise it stays empty.
TplFrameDepStats CreateTplFrameDepStats(int frame_height, int frame_width,
                                        int min_block_size,
                                        bool has_alt_stats) {
  const int grid_rows = (frame_height + min_block_size - 1) / min_block_size;
  const int grid_cols = (frame_width + min_block_size - 1) / min_block_size;
  // Builds a fresh grid_rows x grid_cols grid of default unit stats.
  auto make_grid = [grid_rows, grid_cols]() {
    std::vector<std::vector<TplUnitDepStats>> grid(grid_rows);
    for (std::vector<TplUnitDepStats> &grid_row : grid) {
      grid_row.resize(grid_cols);
    }
    return grid;
  };

  TplFrameDepStats frame_dep_stats;
  frame_dep_stats.unit_size = min_block_size;
  frame_dep_stats.unit_stats = make_grid();
  if (has_alt_stats) frame_dep_stats.alt_unit_stats = make_grid();
  return frame_dep_stats;
}

// Converts one TPL block into the stats stored for each of the |unit_count|
// units it covers. The block's costs are split evenly across the units, and
// its motion vectors and reference indices are copied unchanged.
// With |rate_dist_present| the prediction errors (intra_pred_err /
// inter_pred_err) are used as costs; otherwise intra_cost / inter_cost are.
// The result always has inter_cost <= intra_cost, as GetPropagationFraction()
// requires. A block whose first reference index is negative gets
// inter_cost == intra_cost.
TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
                                        int unit_count,
                                        bool rate_dist_present) {
  const double total_intra =
      rate_dist_present ? static_cast<double>(block_stats.intra_pred_err)
                        : static_cast<double>(block_stats.intra_cost);
  const double total_inter =
      rate_dist_present ? static_cast<double>(block_stats.inter_pred_err)
                        : static_cast<double>(block_stats.inter_cost);

  TplUnitDepStats dep_stats = {};
  dep_stats.intra_cost = total_intra / unit_count;
  dep_stats.inter_cost = total_inter / unit_count;

  const bool has_first_ref = block_stats.ref_frame_index[0] >= 0;
  if (!has_first_ref) {
    // No usable reference: treat the unit as intra-only.
    dep_stats.inter_cost = dep_stats.intra_cost;
  } else {
    // Rarely inter_cost exceeds intra_cost; clamp it.
    dep_stats.inter_cost = std::min(dep_stats.intra_cost, dep_stats.inter_cost);
  }
  dep_stats.mv = block_stats.mv;
  dep_stats.ref_frame_index = block_stats.ref_frame_index;
  return dep_stats;
}

namespace {
// Checks that |block_stats| describes a block whose top-left corner lies
// inside the frame described by |frame_stats|, and whose position and
// dimensions are multiples of |min_block_size|. Returns an
// AOM_CODEC_INVALID_PARAM status with a descriptive message otherwise.
Status ValidateBlockStats(const TplFrameStats &frame_stats,
                          const TplBlockStats &block_stats,
                          int min_block_size) {
  // Negative positions must be rejected as well. They pass the "multiple of
  // min_block_size" test below (e.g. -16 % 16 == 0), and FillTplUnitDepStats()
  // uses row / min_block_size and col / min_block_size directly as grid
  // indices, so a negative value would cause an out-of-bounds write.
  if (block_stats.col < 0 || block_stats.row < 0 ||
      block_stats.col >= frame_stats.frame_width ||
      block_stats.row >= frame_stats.frame_height) {
    std::ostringstream error_message;
    error_message << "Block position (" << block_stats.col << ", "
                  << block_stats.row
                  << ") is out of range; frame dimensions are "
                  << frame_stats.frame_width << " x "
                  << frame_stats.frame_height;
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  if (block_stats.col % min_block_size != 0 ||
      block_stats.row % min_block_size != 0 ||
      block_stats.width % min_block_size != 0 ||
      block_stats.height % min_block_size != 0) {
    std::ostringstream error_message;
    error_message
        << "Invalid block position or dimension, must be a multiple of "
        << min_block_size << "; col = " << block_stats.col
        << ", row = " << block_stats.row << ", width = " << block_stats.width
        << ", height = " << block_stats.height;
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  return { AOM_CODEC_OK, "" };
}

// Checks that |tpl_gop_stats| has one entry per frame of |gop_struct|, and
// that every frame which updates a reference slot (update_ref_idx >= 0) has
// TPL block stats. Returns an AOM_CODEC_INVALID_PARAM status describing the
// first mismatch found, or AOM_CODEC_OK.
Status ValidateTplStats(const GopStruct &gop_struct,
                        const TplGopStats &tpl_gop_stats) {
  constexpr char kAdvice[] =
      "Do the current RateControlParam settings match those used to generate "
      "the TPL stats?";
  if (gop_struct.gop_frame_list.size() !=
      tpl_gop_stats.frame_stats_list.size()) {
    std::ostringstream error_message;
    error_message << "Frame count of GopStruct ("
                  << gop_struct.gop_frame_list.size()
                  << ") doesn't match frame count of TPL stats ("
                  << tpl_gop_stats.frame_stats_list.size() << "). " << kAdvice;
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  for (int i = 0; i < static_cast<int>(gop_struct.gop_frame_list.size()); ++i) {
    const bool is_ref_frame = gop_struct.gop_frame_list[i].update_ref_idx >= 0;
    const bool has_tpl_stats =
        !tpl_gop_stats.frame_stats_list[i].block_stats_list.empty();
    if (is_ref_frame && !has_tpl_stats) {
      std::ostringstream error_message;
      error_message << "The frame with global_coding_idx "
                    << gop_struct.gop_frame_list[i].global_coding_idx
                    << " is a reference frame, but has no TPL stats. "
                    << kAdvice;
      return { AOM_CODEC_INVALID_PARAM, error_message.str() };
    }
  }
  return { AOM_CODEC_OK, "" };
}
}  // namespace

// Writes per-unit dependency stats for every block in |block_stats_list|
// into |unit_stats|. |unit_stats| must already match the frame's unit grid:
// ceil(frame_height / min_block_size) rows by ceil(frame_width /
// min_block_size) columns. Each block's costs are split evenly over the units
// it covers inside the frame; any part of a block past the right or bottom
// edge is dropped. Returns the first ValidateBlockStats() failure, otherwise
// AOM_CODEC_OK.
Status FillTplUnitDepStats(
    std::vector<std::vector<TplUnitDepStats>> &unit_stats,
    const TplFrameStats &frame_stats,
    const std::vector<TplBlockStats> &block_stats_list) {
  const int unit_size = frame_stats.min_block_size;
  const int grid_rows = (frame_stats.frame_height + unit_size - 1) / unit_size;
  const int grid_cols = (frame_stats.frame_width + unit_size - 1) / unit_size;

  for (const TplBlockStats &block : block_stats_list) {
    Status status = ValidateBlockStats(frame_stats, block, unit_size);
    if (!status.ok()) return status;

    const int top = block.row / unit_size;
    const int left = block.col / unit_size;
    // The block starts inside the frame but may extend beyond it; only the
    // units lying within the frame are filled.
    const int rows_in_frame =
        std::min(block.height / unit_size, grid_rows - top);
    const int cols_in_frame =
        std::min(block.width / unit_size, grid_cols - left);
    const TplUnitDepStats per_unit_stats = TplBlockStatsToDepStats(
        block, rows_in_frame * cols_in_frame, frame_stats.rate_dist_present);

    for (int row = top; row < top + rows_in_frame; ++row) {
      for (int col = left; col < left + cols_in_frame; ++col) {
        unit_stats[row][col] = per_unit_stats;
      }
    }
  }
  return { AOM_CODEC_OK, "" };
}

// Builds the (not yet propagated) dependency stats for one frame from its TPL
// block stats. A frame without block stats yields an empty TplFrameDepStats.
// When alternate block stats exist, alt_unit_stats is filled too, and
// rdcost / alt_rdcost are set to the accumulated inter cost of the two grids.
// Returns the validation error from FillTplUnitDepStats(), if any.
StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
    const TplFrameStats &frame_stats) {
  if (frame_stats.block_stats_list.empty()) return TplFrameDepStats();

  const bool has_alt_stats = !frame_stats.alternate_block_stats_list.empty();
  TplFrameDepStats frame_dep_stats = CreateTplFrameDepStats(
      frame_stats.frame_height, frame_stats.frame_width,
      frame_stats.min_block_size, has_alt_stats);

  Status primary_status = FillTplUnitDepStats(
      frame_dep_stats.unit_stats, frame_stats, frame_stats.block_stats_list);
  if (!primary_status.ok()) return primary_status;

  if (!has_alt_stats) return frame_dep_stats;

  Status alt_status =
      FillTplUnitDepStats(frame_dep_stats.alt_unit_stats, frame_stats,
                          frame_stats.alternate_block_stats_list);
  if (!alt_status.ok()) return alt_status;

  frame_dep_stats.rdcost =
      TplFrameDepStatsAccumulateInterCost(frame_dep_stats.unit_stats);
  frame_dep_stats.alt_rdcost =
      TplFrameDepStatsAccumulateInterCost(frame_dep_stats.alt_unit_stats);
  return frame_dep_stats;
}

int GetRefCodingIdxList(const TplUnitDepStats &unit_dep_stats,
                        const RefFrameTable &ref_frame_table,
                        int *ref_coding_idx_list) {
  int ref_frame_count = 0;
  for (int i = 0; i < kBlockRefCount; ++i) {
    ref_coding_idx_list[i] = -1;
    int ref_frame_index = unit_dep_stats.ref_frame_index[i];
    if (ref_frame_index != -1) {
      assert(ref_frame_index < static_cast<int>(ref_frame_table.size()));
      ref_coding_idx_list[i] = ref_frame_table[ref_frame_index].coding_idx;
      ref_frame_count++;
    }
  }
  return ref_frame_count;
}

// Returns the area shared by two size x size squares whose top-left corners
// are (r0, c0) and (r1, c1), or 0 if they do not overlap.
int GetBlockOverlapArea(int r0, int c0, int r1, int c1, int size) {
  // Length of the intersection of [a, a + size) and [b, b + size); negative
  // when the intervals are disjoint.
  auto overlap_1d = [size](int a, int b) {
    return std::min(a, b) + size - std::max(a, b);
  };
  const int row_overlap = overlap_1d(r0, r1);
  const int col_overlap = overlap_1d(c0, c1);
  if (row_overlap < 0 || col_overlap < 0) return 0;
  return row_overlap * col_overlap;
}

// TODO(angiebird): Merge TplFrameDepStatsAccumulateIntraCost and
// TplFrameDepStatsAccumulate.
double TplFrameDepStatsAccumulateIntraCost(
    const TplFrameDepStats &frame_dep_stats) {
  auto getIntraCost = [](double sum, const TplUnitDepStats &unit) {
    return sum + unit.intra_cost;
  };
  double sum = 0;
  for (const auto &row : frame_dep_stats.unit_stats) {
    sum = std::accumulate(row.begin(), row.end(), sum, getIntraCost);
  }
  return std::max(sum, 1.0);
}

double TplFrameDepStatsAccumulateInterCost(
    const std::vector<std::vector<TplUnitDepStats>> &unit_stats) {
  auto getInterCost = [](double sum, const TplUnitDepStats &unit) {
    return sum + unit.inter_cost;
  };
  double sum = 0;
  for (const auto &row : unit_stats) {
    sum = std::accumulate(row.begin(), row.end(), sum, getInterCost);
  }
  return std::max(sum, 1.0);
}

double TplFrameDepStatsAccumulate(const TplFrameDepStats &frame_dep_stats) {
  auto getOverallCost = [](double sum, const TplUnitDepStats &unit) {
    return sum + unit.propagation_cost + unit.intra_cost;
  };
  double sum = 0;
  for (const auto &row : frame_dep_stats.unit_stats) {
    sum = std::accumulate(row.begin(), row.end(), sum, getOverallCost);
  }
  return std::max(sum, 1.0);
}

// This is a generalization of GET_MV_RAWPEL that allows for an arbitrary
// number of fractional bits.
// Converts a value with |subpel_bits| fractional bits to whole pixels,
// rounding halves away from zero (e.g. 4 -> 1 and -4 -> -1 with 3 bits).
// TODO(angiebird): Add unit test to this function
int GetFullpelValue(int subpel_value, int subpel_bits) {
  const int half_unit = (1 << subpel_bits) / 2;
  const int magnitude = (abs(subpel_value) + half_unit) >> subpel_bits;
  return subpel_value < 0 ? -magnitude : magnitude;
}

// Returns (intra_cost - inter_cost) / intra_cost for the unit: the share of
// its cost that inter prediction removes. ModifyDivisor() keeps the divisor
// away from zero. Requires inter_cost <= intra_cost, which
// TplBlockStatsToDepStats() enforces.
double GetPropagationFraction(const TplUnitDepStats &unit_dep_stats) {
  const auto intra = unit_dep_stats.intra_cost;
  const auto inter = unit_dep_stats.inter_cost;
  assert(intra >= inter);
  return (intra - inter) / ModifyDivisor(intra);
}

// Back-propagates TPL costs from the frame at |coding_idx| into the frames it
// references. This is the two-TPL-pass variant: the frame must have both
// unit_stats and alt_unit_stats, otherwise nothing is done.
//
// Reference frames and motion vectors are taken from alt_unit_stats. The same
// geometric weights are then used to update the reference frames'
// alt_unit_stats (from this frame's alt_unit_stats) and their unit_stats (from
// this frame's unit_stats). For key, ARF and golden frames (ignore_inter), only
// propagation_cost is forwarded; otherwise inter_cost + propagation_cost is.
// Unlike TplFrameDepStatsPropagate(), no propagation fraction is applied.
//
// Parameters:
//   coding_idx        - index of the frame in tpl_gop_dep_stats.
//   update_type       - GOP frame type of that frame.
//   ref_frame_table   - reference table visible to that frame; slots are
//                       mapped to coding indices by GetRefCodingIdxList().
//   tpl_gop_dep_stats - dependency stats of all frames; the reference frames'
//                       propagation_cost fields are accumulated in place.
//
// NOTE(review): the bounds checks use this frame's unit grid for the
// reference frames too, so all frames are assumed to share one grid size.
void TplFrameDepStatsBackTrace(int coding_idx, GopFrameType update_type,
                               const RefFrameTable &ref_frame_table,
                               TplGopDepStats *tpl_gop_dep_stats) {
  assert(!tpl_gop_dep_stats->frame_dep_stats_list.empty());
  TplFrameDepStats *frame_dep_stats =
      &tpl_gop_dep_stats->frame_dep_stats_list[coding_idx];

  // Nothing to propagate unless both TPL passes produced stats for the frame.
  if (frame_dep_stats->unit_stats.empty()) return;
  if (frame_dep_stats->alt_unit_stats.empty()) return;

  // For these frame types the frame's own inter cost is not forwarded.
  const bool ignore_inter = update_type == GopFrameType::kRegularKey ||
                            update_type == GopFrameType::kRegularArf ||
                            update_type == GopFrameType::kRegularGolden;

  const int unit_size = frame_dep_stats->unit_size;
  const int frame_unit_rows =
      static_cast<int>(frame_dep_stats->unit_stats.size());
  const int frame_unit_cols =
      static_cast<int>(frame_dep_stats->unit_stats[0].size());
  for (int unit_row = 0; unit_row < frame_unit_rows; ++unit_row) {
    for (int unit_col = 0; unit_col < frame_unit_cols; ++unit_col) {
      TplUnitDepStats &unit_dep_stats =
          frame_dep_stats->unit_stats[unit_row][unit_col];
      TplUnitDepStats &alt_unit_dep_stats =
          frame_dep_stats->alt_unit_stats[unit_row][unit_col];

      // References come from the alternate (second pass) stats.
      int ref_coding_idx_list[kBlockRefCount] = { -1, -1 };
      int ref_frame_count = GetRefCodingIdxList(
          alt_unit_dep_stats, ref_frame_table, ref_coding_idx_list);
      if (ref_frame_count == 0) continue;
      MotionVector base_mv[2] = { alt_unit_dep_stats.mv[0],
                                  alt_unit_dep_stats.mv[1] };
      for (int i = 0; i < kBlockRefCount; ++i) {
        if (ref_coding_idx_list[i] == -1) continue;
        assert(
            ref_coding_idx_list[i] <
            static_cast<int>(tpl_gop_dep_stats->frame_dep_stats_list.size()));
        TplFrameDepStats &ref_frame_dep_stats =
            tpl_gop_dep_stats->frame_dep_stats_list[ref_coding_idx_list[i]];
        assert(!ref_frame_dep_stats.alt_unit_stats.empty());
        // Pixel position this unit's motion vector points to in the reference.
        const auto &mv = base_mv[i];
        const int mv_row = GetFullpelValue(mv.row, mv.subpel_bits);
        const int mv_col = GetFullpelValue(mv.col, mv.subpel_bits);
        const int ref_pixel_r = unit_row * unit_size + mv_row;
        const int ref_pixel_c = unit_col * unit_size + mv_col;
        // Integer division truncates toward zero for negative positions. The
        // exact overlap is recomputed by GetBlockOverlapArea() below, so a
        // unit that does not really overlap receives a zero-area share.
        const int ref_unit_row_low =
            (unit_row * unit_size + mv_row) / unit_size;
        const int ref_unit_col_low =
            (unit_col * unit_size + mv_col) / unit_size;

        // The displaced unit can touch at most a 2x2 group of reference units.
        for (int j = 0; j < 2; ++j) {
          for (int k = 0; k < 2; ++k) {
            const int ref_unit_row = ref_unit_row_low + j;
            const int ref_unit_col = ref_unit_col_low + k;
            if (ref_unit_row >= 0 && ref_unit_row < frame_unit_rows &&
                ref_unit_col >= 0 && ref_unit_col < frame_unit_cols) {
              const int overlap_area = GetBlockOverlapArea(
                  ref_pixel_r, ref_pixel_c, ref_unit_row * unit_size,
                  ref_unit_col * unit_size, unit_size);
              const double overlap_ratio =
                  overlap_area * 1.0 / (unit_size * unit_size);
              // The cost is split evenly between the used reference slots.
              const double propagation_ratio =
                  1.0 / ref_frame_count * overlap_ratio;
              TplUnitDepStats &alt_ref_unit_stats =
                  ref_frame_dep_stats
                      .alt_unit_stats[ref_unit_row][ref_unit_col];
              alt_ref_unit_stats.propagation_cost +=
                  ((ignore_inter ? 0.0 : alt_unit_dep_stats.inter_cost) +
                   alt_unit_dep_stats.propagation_cost) *
                  propagation_ratio;

              TplUnitDepStats &ref_unit_stats =
                  ref_frame_dep_stats.unit_stats[ref_unit_row][ref_unit_col];
              ref_unit_stats.propagation_cost +=
                  ((ignore_inter ? 0.0 : unit_dep_stats.inter_cost) +
                   unit_dep_stats.propagation_cost) *
                  propagation_ratio;
            }
          }
        }
      }
    }
  }
}

// Back-propagates TPL costs from the frame at |coding_idx| into the frames it
// references. This is the single-TPL-pass variant; it does nothing if the
// frame has no unit_stats.
//
// Each unit with at least one used reference slot forwards
// (intra_cost + propagation_cost), scaled by:
//   1 / ref_frame_count              (split between the used slots),
//   overlap area / unit area         (for each reference unit overlapped),
//   GetPropagationFraction()         ((intra - inter) / intra of this unit).
// The result is added to propagation_cost of the overlapped units in each
// referenced frame.
//
// NOTE(review): the bounds checks use this frame's unit grid for the
// reference frames too, so all frames are assumed to share one grid size.
void TplFrameDepStatsPropagate(int coding_idx,
                               const RefFrameTable &ref_frame_table,
                               TplGopDepStats *tpl_gop_dep_stats) {
  assert(!tpl_gop_dep_stats->frame_dep_stats_list.empty());
  TplFrameDepStats *frame_dep_stats =
      &tpl_gop_dep_stats->frame_dep_stats_list[coding_idx];

  if (frame_dep_stats->unit_stats.empty()) return;

  const int unit_size = frame_dep_stats->unit_size;
  const int frame_unit_rows =
      static_cast<int>(frame_dep_stats->unit_stats.size());
  const int frame_unit_cols =
      static_cast<int>(frame_dep_stats->unit_stats[0].size());
  for (int unit_row = 0; unit_row < frame_unit_rows; ++unit_row) {
    for (int unit_col = 0; unit_col < frame_unit_cols; ++unit_col) {
      TplUnitDepStats &unit_dep_stats =
          frame_dep_stats->unit_stats[unit_row][unit_col];
      int ref_coding_idx_list[kBlockRefCount] = { -1, -1 };
      int ref_frame_count = GetRefCodingIdxList(unit_dep_stats, ref_frame_table,
                                                ref_coding_idx_list);
      if (ref_frame_count == 0) continue;
      for (int i = 0; i < kBlockRefCount; ++i) {
        if (ref_coding_idx_list[i] == -1) continue;
        assert(
            ref_coding_idx_list[i] <
            static_cast<int>(tpl_gop_dep_stats->frame_dep_stats_list.size()));
        TplFrameDepStats &ref_frame_dep_stats =
            tpl_gop_dep_stats->frame_dep_stats_list[ref_coding_idx_list[i]];
        assert(!ref_frame_dep_stats.unit_stats.empty());
        // Pixel position this unit's motion vector points to in the reference.
        const auto &mv = unit_dep_stats.mv[i];
        const int mv_row = GetFullpelValue(mv.row, mv.subpel_bits);
        const int mv_col = GetFullpelValue(mv.col, mv.subpel_bits);
        const int ref_pixel_r = unit_row * unit_size + mv_row;
        const int ref_pixel_c = unit_col * unit_size + mv_col;
        // Integer division truncates toward zero for negative positions; the
        // exact overlap is recomputed by GetBlockOverlapArea() below.
        const int ref_unit_row_low =
            (unit_row * unit_size + mv_row) / unit_size;
        const int ref_unit_col_low =
            (unit_col * unit_size + mv_col) / unit_size;

        // The displaced unit can touch at most a 2x2 group of reference units.
        for (int j = 0; j < 2; ++j) {
          for (int k = 0; k < 2; ++k) {
            const int ref_unit_row = ref_unit_row_low + j;
            const int ref_unit_col = ref_unit_col_low + k;
            if (ref_unit_row >= 0 && ref_unit_row < frame_unit_rows &&
                ref_unit_col >= 0 && ref_unit_col < frame_unit_cols) {
              const int overlap_area = GetBlockOverlapArea(
                  ref_pixel_r, ref_pixel_c, ref_unit_row * unit_size,
                  ref_unit_col * unit_size, unit_size);
              const double overlap_ratio =
                  overlap_area * 1.0 / (unit_size * unit_size);
              const double propagation_fraction =
                  GetPropagationFraction(unit_dep_stats);
              const double propagation_ratio =
                  1.0 / ref_frame_count * overlap_ratio * propagation_fraction;
              TplUnitDepStats &ref_unit_stats =
                  ref_frame_dep_stats.unit_stats[ref_unit_row][ref_unit_col];
              ref_unit_stats.propagation_cost +=
                  (unit_dep_stats.intra_cost +
                   unit_dep_stats.propagation_cost) *
                  propagation_ratio;
            }
          }
        }
      }
    }
  }
}

std::vector<RefFrameTable> AV1RateControlQMode::GetRefFrameTableList(
    const GopStruct &gop_struct,
    const std::vector<LookaheadStats> &lookahead_stats,
    RefFrameTable ref_frame_table) {
  if (gop_struct.global_coding_idx_offset == 0) {
    // For the first GOP, ref_frame_table need not be initialized. This is fine,
    // because the first frame (a key frame) will fully initialize it.
    ref_frame_table.assign(rc_param_.ref_frame_table_size, GopFrameInvalid());
  } else {
    // It's not the first GOP, so ref_frame_table must be valid.
    assert(static_cast<int>(ref_frame_table.size()) ==
           rc_param_.ref_frame_table_size);
    assert(std::all_of(ref_frame_table.begin(), ref_frame_table.end(),
                       std::mem_fn(&GopFrame::is_valid)));
    // Reset the frame processing order of the initial ref_frame_table.
    for (GopFrame &gop_frame : ref_frame_table) gop_frame.coding_idx = -1;
  }

  std::vector<RefFrameTable> ref_frame_table_list;
  ref_frame_table_list.push_back(ref_frame_table);
  for (const GopFrame &gop_frame : gop_struct.gop_frame_list) {
    if (gop_frame.is_key_frame) {
      ref_frame_table.assign(rc_param_.ref_frame_table_size, gop_frame);
    } else if (gop_frame.update_ref_idx != -1) {
      assert(gop_frame.update_ref_idx <
             static_cast<int>(ref_frame_table.size()));
      ref_frame_table[gop_frame.update_ref_idx] = gop_frame;
    }
    ref_frame_table_list.push_back(ref_frame_table);
  }

  int gop_size_offset = static_cast<int>(gop_struct.gop_frame_list.size());

  for (const auto &lookahead_stat : lookahead_stats) {
    for (GopFrame gop_frame : lookahead_stat.gop_struct->gop_frame_list) {
      if (gop_frame.is_key_frame) {
        ref_frame_table.assign(rc_param_.ref_frame_table_size, gop_frame);
      } else if (gop_frame.update_ref_idx != -1) {
        assert(gop_frame.update_ref_idx <
               static_cast<int>(ref_frame_table.size()));
        gop_frame.coding_idx += gop_size_offset;
        ref_frame_table[gop_frame.update_ref_idx] = gop_frame;
      }
      ref_frame_table_list.push_back(ref_frame_table);
    }
    gop_size_offset +=
        static_cast<int>(lookahead_stat.gop_struct->gop_frame_list.size());
  }

  return ref_frame_table_list;
}

StatusOr<TplGopDepStats> ComputeTplGopDepStats(
    const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
    const std::vector<LookaheadStats> &lookahead_stats,
    const std::vector<RefFrameTable> &ref_frame_table_list) {
  std::vector<const TplFrameStats *> tpl_frame_stats_list_with_lookahead;
  for (const auto &tpl_frame_stats : tpl_gop_stats.frame_stats_list) {
    tpl_frame_stats_list_with_lookahead.push_back(&tpl_frame_stats);
  }
  for (const auto &lookahead_stat : lookahead_stats) {
    for (const auto &tpl_frame_stats :
         lookahead_stat.tpl_gop_stats->frame_stats_list) {
      tpl_frame_stats_list_with_lookahead.push_back(&tpl_frame_stats);
    }
  }

  const int frame_count =
      static_cast<int>(tpl_frame_stats_list_with_lookahead.size());

  // Create the struct to store TPL dependency stats
  TplGopDepStats tpl_gop_dep_stats;

  tpl_gop_dep_stats.frame_dep_stats_list.reserve(frame_count);
  for (int coding_idx = 0; coding_idx < frame_count; coding_idx++) {
    const StatusOr<TplFrameDepStats> tpl_frame_dep_stats =
        CreateTplFrameDepStatsWithoutPropagation(
            *tpl_frame_stats_list_with_lookahead[coding_idx]);
    if (!tpl_frame_dep_stats.ok()) {
      return tpl_frame_dep_stats.status();
    }
    tpl_gop_dep_stats.frame_dep_stats_list.push_back(
        std::move(*tpl_frame_dep_stats));
  }

  // Back propagation
  for (int coding_idx = frame_count - 1; coding_idx >= 0; coding_idx--) {
    auto &ref_frame_table = ref_frame_table_list[coding_idx];
    // TODO(angiebird): Handle/test the case where reference frame
    // is in the previous GOP

    if (tpl_frame_stats_list_with_lookahead[coding_idx]
            ->alternate_block_stats_list.empty()) {
      // One pass TPL run
      TplFrameDepStatsPropagate(coding_idx, ref_frame_table,
                                &tpl_gop_dep_stats);
    } else {
      // Two pass TPL runs

      int first_gop_size = static_cast<int>(gop_struct.gop_frame_list.size());
      GopFrameType update_type;
      if (coding_idx >= first_gop_size) {
        update_type = lookahead_stats[0]
                          .gop_struct[0]
                          .gop_frame_list[coding_idx - first_gop_size]
                          .update_type;
      } else {
        update_type = gop_struct.gop_frame_list[coding_idx].update_type;
      }

      TplFrameDepStatsBackTrace(coding_idx, update_type, ref_frame_table,
                                &tpl_gop_dep_stats);
    }
  }
  return tpl_gop_dep_stats;
}

// Computes one qindex per 64x64 superblock (raster order) from a frame's TPL
// dependency stats.
//
// For every superblock, beta = r0 / rk with r0 = 1 / frame_importance and
// rk = intra / (mc_dep + intra), accumulated over the TPL units inside it.
// beta is converted to a delta-q offset by av1_get_deltaq_offset(). The offset
// is clamped to +/-(9 * delta_q_res - 1), the qindex to [MINQ, MAXQ], and the
// result is snapped by av1_adjust_q_from_delta_q_res().
//
// When |use_twopass_data| is true:
//   - the incoming |frame_importance| is ignored and recomputed as
//     (sum of max(propagation_cost - alt propagation_cost, 0) + sum of
//     inter_cost) / sum of inter_cost over the whole frame;
//   - per superblock, "mc_dep" is the positive part of propagation_cost minus
//     the alternate propagation_cost, and "intra" is the inter_cost sum.
// Otherwise mc_dep is propagation_cost and intra is intra_cost.
//
// NOTE(review): assumes frame_dep_stats.unit_size is non-zero, and that
// alt_unit_stats is populated whenever use_twopass_data is true — confirm at
// the call sites.
static std::vector<uint8_t> SetupDeltaQ(const TplFrameDepStats &frame_dep_stats,
                                        int frame_width, int frame_height,
                                        int base_qindex,
                                        double frame_importance,
                                        bool use_twopass_data) {
  // TODO(jianj) : Add support to various superblock sizes.
  const int sb_size = 64;
  const int delta_q_res = 4;
  const int num_unit_per_sb = sb_size / frame_dep_stats.unit_size;
  const int sb_rows = (frame_height + sb_size - 1) / sb_size;
  const int sb_cols = (frame_width + sb_size - 1) / sb_size;
  const int unit_rows = (frame_height + frame_dep_stats.unit_size - 1) /
                        frame_dep_stats.unit_size;
  const int unit_cols =
      (frame_width + frame_dep_stats.unit_size - 1) / frame_dep_stats.unit_size;
  std::vector<uint8_t> superblock_q_indices;

  if (use_twopass_data) {
    // Accumulate frame level stats (visited superblock by superblock).
    double cum_inter_cost = 0;
    double cum_rdcost_diff = 0;
    for (int sb_row = 0; sb_row < sb_rows; ++sb_row) {
      for (int sb_col = 0; sb_col < sb_cols; ++sb_col) {
        const int unit_row_start = sb_row * num_unit_per_sb;
        const int unit_row_end =
            std::min((sb_row + 1) * num_unit_per_sb, unit_rows);
        const int unit_col_start = sb_col * num_unit_per_sb;
        const int unit_col_end =
            std::min((sb_col + 1) * num_unit_per_sb, unit_cols);
        // A simplified version of av1_get_q_for_deltaq_objective()
        for (int unit_row = unit_row_start; unit_row < unit_row_end;
             ++unit_row) {
          for (int unit_col = unit_col_start; unit_col < unit_col_end;
               ++unit_col) {
            const TplUnitDepStats &unit_dep_stats =
                frame_dep_stats.unit_stats[unit_row][unit_col];
            const TplUnitDepStats &alt_unit_dep_stats =
                frame_dep_stats.alt_unit_stats[unit_row][unit_col];
            cum_inter_cost += unit_dep_stats.inter_cost;
            cum_rdcost_diff += std::max(unit_dep_stats.propagation_cost -
                                            alt_unit_dep_stats.propagation_cost,
                                        0.0);
          }
        }
      }
    }
    cum_inter_cost = std::max(cum_inter_cost, 1.0);
    // Overrides the caller-supplied frame_importance.
    frame_importance = (cum_rdcost_diff + cum_inter_cost) / cum_inter_cost;
  }

  // Calculate delta_q offset for each superblock.
  for (int sb_row = 0; sb_row < sb_rows; ++sb_row) {
    for (int sb_col = 0; sb_col < sb_cols; ++sb_col) {
      double intra_cost = 0;
      double mc_dep_cost = 0;
      const int unit_row_start = sb_row * num_unit_per_sb;
      const int unit_row_end =
          std::min((sb_row + 1) * num_unit_per_sb, unit_rows);
      const int unit_col_start = sb_col * num_unit_per_sb;
      const int unit_col_end =
          std::min((sb_col + 1) * num_unit_per_sb, unit_cols);
      // A simplified version of av1_get_q_for_deltaq_objective()
      for (int unit_row = unit_row_start; unit_row < unit_row_end; ++unit_row) {
        for (int unit_col = unit_col_start; unit_col < unit_col_end;
             ++unit_col) {
          const TplUnitDepStats &unit_dep_stats =
              frame_dep_stats.unit_stats[unit_row][unit_col];

          if (use_twopass_data) {
            // Two-pass mode: inter_cost stands in for intra_cost here.
            const TplUnitDepStats &alt_unit_dep_stats =
                frame_dep_stats.alt_unit_stats[unit_row][unit_col];
            mc_dep_cost += std::max(unit_dep_stats.propagation_cost -
                                        alt_unit_dep_stats.propagation_cost,
                                    0.0);
            intra_cost += unit_dep_stats.inter_cost;
          } else {
            mc_dep_cost += unit_dep_stats.propagation_cost;
            intra_cost += unit_dep_stats.intra_cost;
          }
        }
      }
      intra_cost = std::max(intra_cost, 1.0);

      // beta > 1 when this superblock matters more than the frame average.
      double beta = 1.0;
      const double r0 = 1 / frame_importance;
      const double rk = intra_cost / (mc_dep_cost + intra_cost);
      beta = r0 / rk;
      assert(beta > 0.0);

      int offset = av1_get_deltaq_offset(AOM_BITS_8, base_qindex, beta);
      offset = std::min(offset, delta_q_res * 9 - 1);
      offset = std::max(offset, -delta_q_res * 9 + 1);
      int qindex = offset + base_qindex;
      qindex = std::min(qindex, MAXQ);
      qindex = std::max(qindex, MINQ);
      qindex = av1_adjust_q_from_delta_q_res(delta_q_res, base_qindex, qindex);
      superblock_q_indices.push_back(static_cast<uint8_t>(qindex));
    }
  }

  return superblock_q_indices;
}

// Assigns every distinct value of |qindices| to the entry of |centroids| with
// the smallest absolute distance to it. Ties go to the centroid that appears
// first in |centroids|, since std::min_element returns the first minimum.
// |centroids| must be non-empty whenever |qindices| is non-empty.
static std::unordered_map<int, double> FindKMeansClusterMap(
    const std::vector<uint8_t> &qindices,
    const std::vector<double> &centroids) {
  std::unordered_map<int, double> assignment;
  for (const uint8_t value : qindices) {
    const auto is_closer = [value](const double lhs, const double rhs) {
      return fabs(lhs - value) < fabs(rhs - value);
    };
    const auto nearest =
        std::min_element(centroids.begin(), centroids.end(), is_closer);
    assignment.insert({ value, *nearest });
  }
  return assignment;
}

namespace internal {

// Groups the distinct values of |qindices| into at most |k| clusters using
// alternating assignment / mean-update steps, and returns a map from each
// distinct qindex to its cluster centroid truncated to int.
//
// Behavioral details:
// - Initial centroids are the first |k| distinct values of |qindices|, in
//   input order. One centroid is always taken when |qindices| is non-empty,
//   even if |k| <= 0.
// - A centroid is the mean of the distinct qindices assigned to it; repeated
//   values in |qindices| do not add weight.
// - Iteration stops when no centroid moves. Centroids that attract no value
//   are dropped, so fewer than |k| clusters may remain.
std::unordered_map<int, int> KMeans(std::vector<uint8_t> qindices, int k) {
  std::vector<double> centroids;
  std::unordered_set<int> already_used;
  for (const uint8_t candidate : qindices) {
    const bool is_new = already_used.insert(candidate).second;
    if (!is_new) continue;
    centroids.push_back(candidate);
    // Checked after the push on purpose: at least one centroid is kept.
    if (static_cast<int>(centroids.size()) >= k) break;
  }

  std::unordered_map<int, double> assignment;
  for (;;) {
    // Assignment step: attach each qindex to its nearest centroid.
    assignment = FindKMeansClusterMap(qindices, centroids);

    // Update step: move each centroid to the mean of its members.
    std::unordered_map<double, std::vector<int>> members_of;
    for (const auto &entry : assignment) {
      members_of[entry.second].push_back(entry.first);
    }
    std::vector<double> updated_centroids;
    bool any_moved = false;
    for (const auto &group : members_of) {
      const std::vector<int> &members = group.second;
      const double mean =
          std::accumulate(members.begin(), members.end(), 0.0) /
          members.size();
      updated_centroids.push_back(mean);
      any_moved = any_moved || mean != group.first;
    }
    if (!any_moved) break;
    centroids.swap(updated_centroids);
  }

  std::unordered_map<int, int> cluster_map;
  for (const auto &entry : assignment) {
    cluster_map.insert({ entry.first, static_cast<int>(entry.second) });
  }
  return cluster_map;
}
}  // namespace internal

// Returns the 8-bit RD multiplier for |gop_frame| encoded at |q_index|, using
// the libaom rdmult rule of the closest FRAME_UPDATE_TYPE.
//
// The key-frame test must come first: SetGopFrameByType() sets both
// is_key_frame and is_golden_frame for kRegularKey frames. Testing
// is_golden_frame first made the KF_UPDATE branch unreachable, so key frames
// silently received the GF_UPDATE rule.
static int GetRDMult(const GopFrame &gop_frame, int q_index) {
  // TODO(angiebird):
  // 1) Check if these rdmult rules are good in our use case.
  // 2) Support high-bit-depth mode
  if (gop_frame.is_key_frame) {
    return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, KF_UPDATE, q_index);
  }
  if (gop_frame.is_golden_frame) {
    // Assume ARF_UPDATE/GF_UPDATE share the same rdmult rule.
    return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, GF_UPDATE, q_index);
  }
  // Assume LF_UPDATE/OVERLAY_UPDATE/INTNL_OVERLAY_UPDATE/INTNL_ARF_UPDATE
  // share the same rdmult rule.
  return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, LF_UPDATE, q_index);
}

// Returns true when |frame|, located at |frame_index| in coding order, has
// |candidate_reference| in its reference list. Each entry of
// frame.ref_frame_list is resolved through ref_frame_table_list[frame_index]
// and compared by global_coding_idx. Overlay and intermediate overlay frames
// are never reported as using a reference.
bool CheckFrameUseReference(
    int frame_index, const GopFrame &frame, const GopFrame &candidate_reference,
    const std::vector<RefFrameTable> &ref_frame_table_list) {
  const bool is_overlay =
      frame.update_type == GopFrameType::kOverlay ||
      frame.update_type == GopFrameType::kIntermediateOverlay;
  if (is_overlay) return false;

  const auto target_coding_idx = candidate_reference.global_coding_idx;
  for (const auto &ref_frame : frame.ref_frame_list) {
    const auto &ref_slot = ref_frame_table_list[frame_index][ref_frame.index];
    if (ref_slot.global_coding_idx == target_coding_idx) return true;
  }
  return false;
}

// Returns the number of frames that use this_frame as reference in the
// current and next subGop.
int CountUsedAsReference(const GopStruct &gop_struct,
                         const std::vector<LookaheadStats> &lookahead_stats,
                         const std::vector<RefFrameTable> &ref_frame_table_list,
                         const GopFrame &this_frame) {
  int num = 0;
  const int frame_count = static_cast<int>(gop_struct.gop_frame_list.size());
  // Check frames in this gop
  for (int i = 0; i < frame_count; ++i) {
    if (CheckFrameUseReference(i, gop_struct.gop_frame_list[i], this_frame,
                               ref_frame_table_list)) {
      ++num;
    }
  }
  // Check frames in the next gop
  if (!lookahead_stats.empty()) {
    const auto &next_gop_frame_list =
        lookahead_stats[0].gop_struct[0].gop_frame_list;
    const int next_gop_frame_count =
        static_cast<int>(next_gop_frame_list.size());
    for (int i = 0; i < next_gop_frame_count; ++i) {
      if (CheckFrameUseReference(frame_count + i, next_gop_frame_list[i],
                                 this_frame, ref_frame_table_list)) {
        ++num;
      }
    }
  }
  return num;
}

int GetIntArfQ(const GopStruct &gop_struct,
               const std::vector<LookaheadStats> &lookahead_stats,
               const std::vector<RefFrameTable> &ref_frame_table_list,
               const GopFrame &arf_frame, const GopFrame &int_arf_frame,
               int active_best_quality, int active_worst_quality) {
  if (!arf_frame.is_valid) return active_best_quality;
  assert(int_arf_frame.is_valid);

  // Check whether this is the first intermediate arf
  bool is_first_int_arf = false;
  for (const auto &gop_frame : gop_struct.gop_frame_list) {
    if (gop_frame.update_type == GopFrameType::kIntermediateArf) {
      is_first_int_arf =
          gop_frame.global_coding_idx == int_arf_frame.global_coding_idx;
      break;
    }
  }
  if (is_first_int_arf) {
    // int_arf_frame is the first intermediate arf in the subGop
    int arf_as_ref = CountUsedAsReference(gop_struct, lookahead_stats,
                                          ref_frame_table_list, arf_frame);
    int int_arf_as_ref = CountUsedAsReference(
        gop_struct, lookahead_stats, ref_frame_table_list, int_arf_frame);
    int arf_adjusted =
        arf_as_ref + static_cast<int>(int_arf_as_ref * kIntArfAdjFactor);
    if (arf_adjusted <= int_arf_as_ref) {
      return active_best_quality;
    } else {
      assert(arf_adjusted > 0);
      return active_best_quality +
             (active_worst_quality - active_best_quality) *
                 (arf_adjusted - int_arf_as_ref) / arf_adjusted;
    }
  } else {
    // int_arf_frame is not the first intermediate arf in the subGop
    assert(int_arf_frame.layer_depth >= 1);
    const int depth_factor = 1 << (int_arf_frame.layer_depth - 1);
    return (active_worst_quality * (depth_factor - 1) + active_best_quality) /
           depth_factor;
  }
}

// Builds encode parameters for a TPL pass, where no statistics are
// available yet.
//
// Every frame gets the GOP base q_index (rc base q_index shifted by
// gop_struct.base_q_ratio) and the LF_UPDATE rdmult for it, so QP stays
// constant within the GOP but may change across GOPs. With two TPL passes,
// key/golden/ARF frames use the ARF_UPDATE rdmult at kSecondTplPassQp. On the
// second pass (tpl_pass_index != 0) they also use kSecondTplPassQp as
// q_index.
// NOTE(review): on the first TPL pass the anchor rdmult is computed for
// kSecondTplPassQp while q_index stays at the base q_index. Confirm this
// mismatch is intended.
StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithNoStats(
    const GopStruct &gop_struct) {
  const int base_q_index =
      rc_param_.base_q_index +
      av1_get_deltaq_offset(AOM_BITS_8, rc_param_.base_q_index,
                            gop_struct.base_q_ratio);
  const bool two_tpl_passes =
      rc_param_.tpl_pass_count == TplPassCount::kTwoTplPasses;

  GopEncodeInfo gop_encode_info;
  for (const GopFrame &gop_frame : gop_struct.gop_frame_list) {
    FrameEncodeParameters param;
    param.q_index = base_q_index;
    param.rdmult = av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, LF_UPDATE,
                                                       base_q_index);

    const bool is_anchor = gop_frame.update_type == GopFrameType::kRegularGolden ||
                           gop_frame.update_type == GopFrameType::kRegularKey ||
                           gop_frame.update_type == GopFrameType::kRegularArf;
    if (two_tpl_passes && is_anchor) {
      if (rc_param_.tpl_pass_index) param.q_index = kSecondTplPassQp;
      param.rdmult = av1_compute_rd_mult_based_on_qindex(
          AOM_BITS_8, ARF_UPDATE, kSecondTplPassQp);
    }
    gop_encode_info.param_list.push_back(param);
  }
  return gop_encode_info;
}

// Returns true if the frame at |index| or the frame just before it is flagged
// as a flash in the first-pass stats. |index| must be at least 1.
bool CheckFlash(const std::vector<FIRSTPASS_STATS> &stats_list, int index) {
  assert(index >= 1);
  if (stats_list[index].is_flash) return true;
  return static_cast<bool>(stats_list[index - 1].is_flash);
}

// Chooses the first-pass stats positions that GetAccumulatedScore() uses for
// an anchor frame of type |update_type|:
//   this_index                - position of the anchor frame itself;
//   first_index / last_index  - inclusive range of frames it may influence;
//   ref_before_index / ref_after_index
//                             - reference frames that bound that range, or
//                               -1 when there is none.
// Positions index the first-pass stats list, which presumably starts at the
// current GOP (TODO confirm). |this_gop_len| and |next_gop_len| are the
// show-frame counts of the current and next GOP. Key and golden frames are
// handled explicitly; every other type is treated as an ARF.
inline void SetUpFrameIndices(GopFrameType update_type, int stats_size,
                              int this_gop_len, int next_gop_len,
                              int &this_index, int &first_index,
                              int &last_index, int &ref_before_index,
                              int &ref_after_index) {
  switch (update_type) {
    case GopFrameType::kRegularKey:
      this_index = 0;
      first_index = 1;
      last_index = stats_size - 1;
      ref_before_index = -1;
      ref_after_index = -1;
      return;
    case GopFrameType::kRegularGolden:
      // TODO(b/260859962): Need to consider the situation when arf is not
      // used
      this_index = 0;
      first_index = 1;
      last_index = this_gop_len - 2;
      ref_before_index = -1;
      ref_after_index = this_gop_len - 1;
      return;
    default:
      break;
  }

  // ARF.
  // TODO(b/260859962): It looks like in this case the last arf should
  // actually be at index -1. This for now should be accurate enough, but
  // in the future it is better to have the exact index of last arf.
  const bool use_next_ref = next_gop_len >= 4;
  const int lookahead_end = this_gop_len + next_gop_len;
  this_index = this_gop_len - 1;
  first_index = 1;
  last_index = use_next_ref ? lookahead_end - 2 : lookahead_end - 1;
  ref_before_index = 0;
  ref_after_index = use_next_ref ? lookahead_end - 1 : -1;
}

// Returns sqrt(max((intra_error - noise_var) / intra_error, 0.5)): the weight
// a frame's correlation contributes to an accumulated score. A larger
// noise_var relative to intra_error lowers the weight, down to a floor of
// sqrt(0.5). The divisor is kept away from zero (same epsilon as
// ModifyDivisor()), so a frame with intra_error == noise_var == 0 gets the
// floor weight instead of a NaN that would poison the whole score.
static double NoiseAdjustedFrameWeight(double intra_error, double noise_var) {
  const double kEpsilon = 0.0000001;
  const double divisor = intra_error < 0 ? std::min(intra_error, -kEpsilon)
                                         : std::max(intra_error, kEpsilon);
  return sqrt(std::max((intra_error - noise_var) / divisor, 0.5));
}

// Return the accumulated score of a frame, considering its influence on the
// frames from first_index to last_index (both inclusive). When ref_before_index
// >= 0, only consider the frames where the current frame has a larger
// correlation than the frame at ref_before_index. Same for ref_after_index.
// This function also calculates and returns the average correlation coefficient
// of this frame to the affected frames through the parameter avg_correlation.
//
// The influence on frame f is the product of cor_coeff along the path from
// this_index to f, skipping links that touch a flash frame (see CheckFlash()).
// That product is scaled by NoiseAdjustedFrameWeight() of frame f. Scanning in
// each direction stops at the first frame where the competing reference frame
// has a stronger influence.
double GetAccumulatedScore(const FirstpassInfo &firstpass_info, int this_index,
                           int first_index, int last_index,
                           int ref_before_index, int ref_after_index,
                           double &avg_correlation) {
  assert(ref_before_index < 0 || ref_before_index < first_index);
  assert(ref_after_index < 0 || ref_after_index > last_index);
  const auto &stats_list = firstpass_info.stats_list;
  double score = 0.0;
  int count = 0;
  avg_correlation = 0.0;

  // Check the influence of this frame to the frames before it
  for (int f = this_index - 1; f >= first_index; --f) {
    // The contribution of this frame to frame f
    double coeff_this = 1.0;
    for (int k = this_index; k > f; --k) {
      if (CheckFlash(stats_list, k)) continue;
      coeff_this *= stats_list[k].cor_coeff;
    }
    // The contribution of frame at ref_before_index to frame f
    if (ref_before_index >= 0) {
      double coeff_last = 1.0;
      for (int k = ref_before_index + 1; k <= f; ++k) {
        if (CheckFlash(stats_list, k)) continue;
        coeff_last *= stats_list[k].cor_coeff;
      }
      if (coeff_last > coeff_this) break;
    }
    ++count;
    // cor_coeff[f + 1] is the link between frame f and frame f + 1.
    avg_correlation += stats_list[f + 1].cor_coeff;

    // If this is a flash, although we ignore it in the accumulation, we
    // still count it for this frame so it will probably have a low
    // correlation
    if (stats_list[f].is_flash) coeff_this *= stats_list[f].cor_coeff;

    score += coeff_this * NoiseAdjustedFrameWeight(stats_list[f].intra_error,
                                                   stats_list[f].noise_var);
  }

  // Check the influence of this frame to the frames after it
  for (int f = this_index + 1; f <= last_index; ++f) {
    // The contribution of this frame to frame f
    double coeff_this = 1.0;
    for (int k = this_index + 1; k <= f; ++k) {
      if (CheckFlash(stats_list, k)) continue;
      coeff_this *= stats_list[k].cor_coeff;
    }

    // The contribution of frame at ref_after_index to frame f
    if (ref_after_index >= 0) {
      double coeff_next = 1.0;
      for (int k = ref_after_index; k > f; --k) {
        if (CheckFlash(stats_list, k)) continue;
        coeff_next *= stats_list[k].cor_coeff;
      }
      if (coeff_next > coeff_this) break;
    }
    ++count;
    // cor_coeff[f] is the link between frame f - 1 and frame f.
    avg_correlation += stats_list[f].cor_coeff;

    // If this is a flash, although we ignore it in the accumulation, we
    // still count it for this frame so it will probably have a low
    // correlation
    if (stats_list[f].is_flash) coeff_this *= stats_list[f].cor_coeff;

    score += coeff_this * NoiseAdjustedFrameWeight(stats_list[f].intra_error,
                                                   stats_list[f].noise_var);
  }
  if (count > 0) avg_correlation /= static_cast<double>(count);
  return score;
}

// Lowers |q_index| for nearly static content. When |avg_correlation| is below
// 0.99, |q_index| is returned unchanged. Otherwise it is divided by
// 1 + q_index * score / 400 and truncated toward zero.
int AdjustStaticQp(double avg_correlation, double score, int q_index) {
  if (avg_correlation < 0.99) {
    // Not static enough to justify a lower q_index.
    return q_index;
  }
  const double scaled_score = score * q_index / 400.0;
  const double divisor = 1.0 + scaled_score;
  return static_cast<int>(q_index / divisor);
}

// Computes per-frame encode parameters for |gop_struct| from first-pass
// statistics only (no TPL stats).
//
// - Leaf and overlay frames use the GOP base q_index.
// - Key, golden and ARF frames get a boost of sqrt(GetAccumulatedScore()),
//   clamped to [1, 6] for key frames and [1, 4] otherwise. Their q_index is
//   then obtained from av1_get_q_index_from_qstep_ratio() with a qstep ratio
//   of 1/boost relative to the base q_index.
// - All remaining frame types (intermediate ARFs) are placed between the most
//   recent anchor q_index and the base q_index by GetIntArfQ().
//
// Returns AOM_CODEC_INVALID_PARAM if |firstpass_info| does not cover the
// current GOP plus the first lookahead GOP.
StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithFp(
    const GopStruct &gop_struct, const FirstpassInfo &firstpass_info,
    const std::vector<LookaheadStats> &lookahead_stats,
    const RefFrameTable &ref_frame_table_snapshot_init) {
  const int frame_count = static_cast<int>(gop_struct.gop_frame_list.size());
  GopEncodeInfo gop_encode_info;

  const std::vector<RefFrameTable> ref_frame_table_list = GetRefFrameTableList(
      gop_struct, lookahead_stats, ref_frame_table_snapshot_init);
  gop_encode_info.final_snapshot = ref_frame_table_list[frame_count];

  const int stats_size = static_cast<int>(firstpass_info.stats_list.size());
  const int this_gop_len = gop_struct.show_frame_count;
  const int next_gop_len =
      lookahead_stats.empty()
          ? 0
          : lookahead_stats[0].gop_struct[0].show_frame_count;
  if (stats_size < this_gop_len + next_gop_len) {
    Status status;
    status.code = AOM_CODEC_INVALID_PARAM;
    status.message = "The firstpass info length is insufficient.";
    return status;
  }

  // |firstpass_info| is a const reference, so std::move() could never move
  // from it; pass it plainly. Analysis runs only after validation succeeds.
  const FirstpassInfo analyzed_fp_info = AnalyzeFpStats(firstpass_info);

  GopFrame arf_frame = GopFrameInvalid();

  const int base_offset = av1_get_deltaq_offset(
      AOM_BITS_8, rc_param_.base_q_index, gop_struct.base_q_ratio);
  const int base_q_index = rc_param_.base_q_index + base_offset;
  const int active_worst_quality = base_q_index;
  // Tracks the q_index of the latest key/golden/ARF frame.
  int active_best_quality = base_q_index;
  for (int i = 0; i < frame_count; ++i) {
    FrameEncodeParameters param;
    const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
    if (gop_frame.update_type == GopFrameType::kOverlay ||
        gop_frame.update_type == GopFrameType::kIntermediateOverlay ||
        gop_frame.update_type == GopFrameType::kRegularLeaf) {
      param.q_index = base_q_index;
    } else if (gop_frame.update_type == GopFrameType::kRegularKey ||
               gop_frame.update_type == GopFrameType::kRegularGolden ||
               gop_frame.update_type == GopFrameType::kRegularArf) {
      int this_index, first_index, last_index, ref_before_index,
          ref_after_index;
      SetUpFrameIndices(gop_frame.update_type, stats_size, this_gop_len,
                        next_gop_len, this_index, first_index, last_index,
                        ref_before_index, ref_after_index);

      double avg_correlation = 0;
      const double score = GetAccumulatedScore(
          analyzed_fp_info, this_index, first_index, last_index,
          ref_before_index, ref_after_index, avg_correlation);
      const double boost = std::min(
          std::max(sqrt(score), 1.0),
          gop_frame.update_type == GopFrameType::kRegularKey ? 6.0 : 4.0);
      const double qstep_ratio = 1.0 / boost;
      param.q_index = av1_get_q_index_from_qstep_ratio(base_q_index,
                                                       qstep_ratio, AOM_BITS_8);

      // Never let a non-lossless GOP produce a lossless (q_index 0) anchor.
      if (base_q_index) param.q_index = std::max(param.q_index, 1);
      active_best_quality = param.q_index;

      if (gop_frame.update_type == GopFrameType::kRegularArf) {
        arf_frame = gop_frame;
      }
    } else {
      param.q_index = GetIntArfQ(gop_struct, lookahead_stats,
                                 ref_frame_table_list, arf_frame, gop_frame,
                                 active_best_quality, active_worst_quality);
    }
    param.rdmult = GetRDMult(gop_frame, param.q_index);
    gop_encode_info.param_list.push_back(param);
  }
  return gop_encode_info;
}

// Computes per-frame (and optionally per-superblock) encode parameters for
// |gop_struct| from TPL statistics.
//
// - Leaf and overlay frames use the GOP base q_index.
// - For key, golden and ARF frames, frame importance is the ratio of the TPL
//   cost accumulated with propagation to the intra-only cost. With two TPL
//   passes, it is instead a GOP-level value derived from the rdcost and
//   alt_rdcost of the non-anchor frames. The q_index uses a qstep ratio of
//   sqrt(1 / importance) relative to the base q_index.
// - When more than one distinct q_index per frame is allowed, superblock
//   q_indices from SetupDeltaQ() are clustered with internal::KMeans(). Each
//   superblock's offset from the frame q_index is then truncated to a
//   multiple of the delta-q step.
// - Remaining frame types (intermediate ARFs) are interpolated by layer depth
//   between the latest anchor q_index and the base q_index.
//
// Returns an error if TPL dependency stats cannot be computed, or if
// |firstpass_info| does not cover the current GOP plus the first lookahead
// GOP.
StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithTpl(
    const GopStruct &gop_struct, const FirstpassInfo &firstpass_info,
    const TplGopStats &tpl_gop_stats,
    const std::vector<LookaheadStats> &lookahead_stats,
    const RefFrameTable &ref_frame_table_snapshot_init) {
  assert(tpl_gop_stats.frame_stats_list.size() ==
         gop_struct.gop_frame_list.size());
  const int frame_count =
      static_cast<int>(tpl_gop_stats.frame_stats_list.size());

  GopEncodeInfo gop_encode_info;

  const std::vector<RefFrameTable> ref_frame_table_list = GetRefFrameTableList(
      gop_struct, lookahead_stats, ref_frame_table_snapshot_init);
  gop_encode_info.final_snapshot = ref_frame_table_list[frame_count];

  StatusOr<TplGopDepStats> gop_dep_stats = ComputeTplGopDepStats(
      gop_struct, tpl_gop_stats, lookahead_stats, ref_frame_table_list);
  if (!gop_dep_stats.ok()) {
    return gop_dep_stats.status();
  }

  // |firstpass_info| is only used to verify the stats cover enough frames.
  const int stats_size = static_cast<int>(firstpass_info.stats_list.size());

  const int this_gop_len = gop_struct.show_frame_count;
  const int next_gop_len =
      lookahead_stats.empty()
          ? 0
          : lookahead_stats[0].gop_struct[0].show_frame_count;
  if (stats_size < this_gop_len + next_gop_len) {
    Status status;
    status.code = AOM_CODEC_INVALID_PARAM;
    status.message = "The firstpass info length is insufficient.";
    return status;
  }

  const int base_offset = av1_get_deltaq_offset(
      AOM_BITS_8, rc_param_.base_q_index, gop_struct.base_q_ratio);
  const int base_q_index = rc_param_.base_q_index + base_offset;

  const int active_worst_quality = base_q_index;
  // Tracks the q_index of the latest key/golden/ARF frame.
  int active_best_quality = base_q_index;

  double base_rdcost = 1.0;  // baseline total rdcost
  double hqr_rdcost = 0;     // high quality reference total rdcost
  double arf_rdcost_high = 1.0;

  bool kf_arf_seen = false;

  for (int i = 0; i < frame_count; ++i) {
    const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
    const TplFrameDepStats &frame_dep_stats =
        gop_dep_stats->frame_dep_stats_list[i];
    if (gop_frame.update_type == GopFrameType::kRegularGolden ||
        gop_frame.update_type == GopFrameType::kRegularKey ||
        gop_frame.update_type == GopFrameType::kRegularArf) {
      // Only the first key/golden/ARF frame of the GOP contributes.
      if (!kf_arf_seen) {
        arf_rdcost_high += frame_dep_stats.rdcost;
      }
      kf_arf_seen = true;
    } else {
      base_rdcost += frame_dep_stats.rdcost;
      hqr_rdcost += frame_dep_stats.alt_rdcost;
    }
  }

  // arf_rdcost_high starts at 1.0, so this division cannot be by zero.
  const double tp_frame_importance =
      1.0 + fabs((base_rdcost - hqr_rdcost) / arf_rdcost_high);

  for (int i = 0; i < frame_count; ++i) {
    FrameEncodeParameters param;
    const GopFrame &gop_frame = gop_struct.gop_frame_list[i];

    if (gop_frame.update_type == GopFrameType::kOverlay ||
        gop_frame.update_type == GopFrameType::kIntermediateOverlay ||
        gop_frame.update_type == GopFrameType::kRegularLeaf) {
      param.q_index = base_q_index;
    } else if (gop_frame.update_type == GopFrameType::kRegularGolden ||
               gop_frame.update_type == GopFrameType::kRegularKey ||
               gop_frame.update_type == GopFrameType::kRegularArf) {
      const TplFrameDepStats &frame_dep_stats =
          gop_dep_stats->frame_dep_stats_list[i];
      const double cost_without_propagation =
          TplFrameDepStatsAccumulateIntraCost(frame_dep_stats);
      const double cost_with_propagation =
          TplFrameDepStatsAccumulate(frame_dep_stats);
      double frame_importance =
          cost_with_propagation / cost_without_propagation;

      // TODO(jingning): Temporarily make the switch between single and
      // two TPL passes depending on the availability. This part of code
      // needs further modifications to support SB level calculation.
      if (rc_param_.tpl_pass_count == TplPassCount::kTwoTplPasses)
        frame_importance = tp_frame_importance;

      // Imitate the behavior of av1_tpl_get_qstep_ratio()
      const double qstep_ratio = sqrt(1 / frame_importance);
      param.q_index = av1_get_q_index_from_qstep_ratio(base_q_index,
                                                       qstep_ratio, AOM_BITS_8);

      // Never let a non-lossless GOP produce a lossless (q_index 0) anchor.
      if (base_q_index) param.q_index = std::max(param.q_index, 1);
      active_best_quality = param.q_index;

      if (rc_param_.max_distinct_q_indices_per_frame > 1) {
        const std::vector<uint8_t> superblock_q_indices = SetupDeltaQ(
            frame_dep_stats, rc_param_.frame_width, rc_param_.frame_height,
            param.q_index, frame_importance,
            rc_param_.tpl_pass_count == TplPassCount::kTwoTplPasses);
        const std::unordered_map<int, int> qindex_centroids = internal::KMeans(
            superblock_q_indices, rc_param_.max_distinct_q_indices_per_frame);
        // Superblock offsets from the frame q_index are truncated to a
        // multiple of this step.
        const int delta_q_res = 4;
        for (const uint8_t sb_qindex : superblock_q_indices) {
          const auto centroid_it = qindex_centroids.find(sb_qindex);
          // KMeans() maps every input qindex, so the lookup cannot fail.
          assert(centroid_it != qindex_centroids.end());
          const int curr_sb_qindex = centroid_it->second;
          const int adjusted_qindex =
              param.q_index +
              (curr_sb_qindex - param.q_index) / delta_q_res * delta_q_res;
          const int rd_mult = GetRDMult(gop_frame, adjusted_qindex);
          param.superblock_encode_params.push_back(
              { static_cast<uint8_t>(adjusted_qindex), rd_mult });
        }
      }
    } else {
      // TODO(b/259601830): Also consider using GetIntArfQ here.
      // Intermediate ARFs
      assert(gop_frame.layer_depth >= 1);
      const int depth_factor = 1 << (gop_frame.layer_depth - 1);
      param.q_index =
          (active_worst_quality * (depth_factor - 1) + active_best_quality) /
          depth_factor;
    }
    param.rdmult = GetRDMult(gop_frame, param.q_index);
    gop_encode_info.param_list.push_back(param);
  }
  return gop_encode_info;
}

// Encode parameters for a TPL pass. They ignore |firstpass_info| and are
// produced entirely by GetGopEncodeInfoWithNoStats(): a constant base q_index,
// plus the anchor-frame adjustments applied there when two TPL passes are
// configured.
StatusOr<GopEncodeInfo> AV1RateControlQMode::GetTplPassGopEncodeInfo(
    const GopStruct &gop_struct,
    const FirstpassInfo &firstpass_info AOM_UNUSED) {
  return this->GetGopEncodeInfoWithNoStats(gop_struct);
}

// Computes the encode parameters of |gop_struct| from TPL statistics.
//
// |tpl_gop_stats| is validated against |gop_struct|, and the TPL stats of
// every lookahead GOP are validated against that GOP's structure. The first
// validation failure is returned as-is; there is no fallback to
// first-pass-only parameters here. On success the work is delegated to
// GetGopEncodeInfoWithTpl(), which also uses |firstpass_info|.
StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfo(
    const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
    const std::vector<LookaheadStats> &lookahead_stats,
    const FirstpassInfo &firstpass_info,
    const RefFrameTable &ref_frame_table_snapshot_init) {
  Status status = ValidateTplStats(gop_struct, tpl_gop_stats);
  if (!status.ok()) {
    return status;
  }

  for (const auto &lookahead_stat : lookahead_stats) {
    Status lookahead_status = ValidateTplStats(*lookahead_stat.gop_struct,
                                               *lookahead_stat.tpl_gop_stats);
    if (!lookahead_status.ok()) {
      return lookahead_status;
    }
  }

  return GetGopEncodeInfoWithTpl(gop_struct, firstpass_info, tpl_gop_stats,
                                 lookahead_stats,
                                 ref_frame_table_snapshot_init);
}

}  // namespace aom
