RC: Implement setting key frame decision

Bug: aomedia:450252793
Change-Id: I9c8a52f918b13830232ea125248115afb454a05c
diff --git a/av1/encoder/av1_ext_ratectrl.c b/av1/encoder/av1_ext_ratectrl.c
index f7187a4..75182b6 100644
--- a/av1/encoder/av1_ext_ratectrl.c
+++ b/av1/encoder/av1_ext_ratectrl.c
@@ -167,6 +167,18 @@
   return AOM_CODEC_OK;
 }
 
+aom_codec_err_t av1_extrc_get_key_frame_decision(
+    AOM_EXT_RATECTRL *ext_ratectrl,
+    aom_rc_key_frame_decision_t *key_frame_decision) {
+  assert(ext_ratectrl != NULL);
+  if (!ext_ratectrl->ready || (ext_ratectrl->funcs.rc_type & AOM_RC_GOP) == 0) {
+    return AOM_CODEC_INVALID_PARAM;
+  }
+  aom_rc_status_t rc_status = ext_ratectrl->funcs.get_key_frame_decision(
+      ext_ratectrl->model, key_frame_decision);
+  return rc_status == AOM_RC_OK ? AOM_CODEC_OK : AOM_CODEC_ERROR;
+}
+
 aom_codec_err_t av1_extrc_delete(AOM_EXT_RATECTRL *ext_ratectrl) {
   if (ext_ratectrl == NULL) {
     return AOM_CODEC_INVALID_PARAM;
diff --git a/av1/encoder/pass2_strategy.c b/av1/encoder/pass2_strategy.c
index 68964d8..dac8758 100644
--- a/av1/encoder/pass2_strategy.c
+++ b/av1/encoder/pass2_strategy.c
@@ -3284,15 +3284,30 @@
   kf_raw_err = this_frame->intra_error;
   kf_mod_err = calculate_modified_err(frame_info, twopass, oxcf, this_frame);
 
-  // We assume the current frame is a key frame and we are looking for the next
-  // key frame. Therefore search_start_idx = 1
-  frames_to_key = define_kf_interval(cpi, firstpass_info, kf_cfg->key_freq_max,
-                                     /*search_start_idx=*/1);
-
-  if (frames_to_key != -1) {
-    rc->frames_to_key = AOMMIN(kf_cfg->key_freq_max, frames_to_key);
+  if (cpi->ext_ratectrl.ready &&
+      (cpi->ext_ratectrl.funcs.rc_type & AOM_RC_GOP) != 0 &&
+      cpi->ext_ratectrl.funcs.get_key_frame_decision != NULL) {
+    aom_rc_key_frame_decision_t key_frame_decision;
+    aom_codec_err_t codec_status = av1_extrc_get_key_frame_decision(
+        &cpi->ext_ratectrl, &key_frame_decision);
+    if (codec_status == AOM_CODEC_OK) {
+      rc->frames_to_key = key_frame_decision.key_frame_group_size;
+    } else {
+      aom_internal_error(cpi->common.error, codec_status,
+                         "av1_extrc_get_key_frame_decision() failed");
+    }
   } else {
-    rc->frames_to_key = kf_cfg->key_freq_max;
+    // We assume the current frame is a key frame and we are looking for the
+    // next key frame. Therefore search_start_idx = 1
+    frames_to_key =
+        define_kf_interval(cpi, firstpass_info, kf_cfg->key_freq_max,
+                           /*search_start_idx=*/1);
+
+    if (frames_to_key != -1) {
+      rc->frames_to_key = AOMMIN(kf_cfg->key_freq_max, frames_to_key);
+    } else {
+      rc->frames_to_key = kf_cfg->key_freq_max;
+    }
   }
 
   if (cpi->ppi->lap_enabled) correct_frames_to_key(cpi);
diff --git a/test/ext_ratectrl_test.cc b/test/ext_ratectrl_test.cc
index 269de04..5f950d3 100644
--- a/test/ext_ratectrl_test.cc
+++ b/test/ext_ratectrl_test.cc
@@ -45,6 +45,9 @@
 // A flag to indicate if update_encodeframe_result() is called.
 bool is_update_encodeframe_result_called = false;
 
+// A flag to indicate if get_key_frame_decision() is called.
+bool is_get_key_frame_decision_called = false;
+
 // Variables to store the parameters passed to update_encodeframe_result().
 int64_t bit_count = 0;
 int actual_encoding_qindex = 0;
@@ -53,6 +56,14 @@
 const int kGopFrameCount = kFrameNum + 1;
 aom_rc_gop_frame_t gop_frame_list[kGopFrameCount];
 
+aom_rc_status_t mock_get_key_frame_decision(
+    aom_rc_model_t /*ratectrl_model*/,
+    aom_rc_key_frame_decision_t *key_frame_decision) {
+  key_frame_decision->key_frame_group_size = 1;
+  is_get_key_frame_decision_called = true;
+  return AOM_RC_OK;
+}
+
 aom_rc_status_t mock_get_gop_decision(aom_rc_model_t /*ratectrl_model*/,
                                       aom_rc_gop_decision_t *gop_decision) {
   gop_decision->gop_frame_count = kGopFrameCount;
@@ -152,6 +163,7 @@
     rc_funcs->send_firstpass_stats = mock_send_firstpass_stats;
     rc_funcs->send_tpl_gop_stats = mock_send_extrc_tpl_gop_stats;
     rc_funcs->get_gop_decision = nullptr;
+    rc_funcs->get_key_frame_decision = nullptr;
     rc_funcs->get_encodeframe_decision = nullptr;
     rc_funcs->update_encodeframe_result = nullptr;
   }
@@ -167,6 +179,7 @@
     is_send_extrc_tpl_gop_stats_called = false;
     is_get_gop_decision_called = false;
     is_update_encodeframe_result_called = false;
+    is_get_key_frame_decision_called = false;
   }
 
   void PreEncodeFrameHook(::libaom_test::VideoSource *video,
@@ -326,4 +339,36 @@
 AV1_INSTANTIATE_TEST_SUITE(ExtRateCtrlGopTest,
                            ::testing::Values(::libaom_test::kTwoPassGood),
                            ::testing::Values(3));
+
+class ExtRateCtrlKeyFrameTest : public ExtRateCtrlTest {
+ protected:
+  ExtRateCtrlKeyFrameTest() {
+    rc_funcs_.rc_type = AOM_RC_GOP;
+    rc_funcs_.get_key_frame_decision = mock_get_key_frame_decision;
+  }
+
+  ~ExtRateCtrlKeyFrameTest() override = default;
+
+  void SetUp() override {
+    ExtRateCtrlTest::SetUp();
+    is_get_key_frame_decision_called = false;
+  }
+
+  void FramePktHook(const aom_codec_cx_pkt_t *pkt) override {
+    if (pkt->kind != AOM_CODEC_CX_FRAME_PKT) return;
+    EXPECT_TRUE(pkt->data.frame.flags & AOM_FRAME_IS_KEY);
+  }
+};
+
+TEST_P(ExtRateCtrlKeyFrameTest, TestExternalRateCtrlKeyFrame) {
+  ::libaom_test::Y4mVideoSource video("screendata.y4m", 0, kFrameNum);
+  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
+  EXPECT_TRUE(is_create_model_called);
+  EXPECT_TRUE(is_get_key_frame_decision_called);
+  EXPECT_TRUE(is_delete_model_called);
+}
+
+AV1_INSTANTIATE_TEST_SUITE(ExtRateCtrlKeyFrameTest,
+                           ::testing::Values(::libaom_test::kTwoPassGood),
+                           ::testing::Values(3));
 }  // namespace