feat: Add AOM_EFLAG_FREEZE_INTERNAL_STATE

Add a flag to encode a frame without updating encoder state.
This is useful for speculative encoding.

The flag ensures the current frame is encoded normally and PSNR is
generated if requested, but internal state such as reference buffers,
CDF tables, and rate control state is not modified.

Add tests for single-layer and SVC encoding.

Change-Id: I34e9e3d0beb355cff2f9b3419820773e5d8c550d
diff --git a/aom/aom_encoder.h b/aom/aom_encoder.h
index fa5f643..4b48cea 100644
--- a/aom/aom_encoder.h
+++ b/aom/aom_encoder.h
@@ -379,6 +379,9 @@
 #define AOM_EFLAG_FORCE_KF (1 << 0)
 /*!\brief Calculate PSNR for this frame, requires g_lag_in_frames to be 0 */
 #define AOM_EFLAG_CALCULATE_PSNR (1 << 1)
+/*!\brief Encode this frame without updating internal encoder state
+ * (reference buffers, entropy tables, rate control state, etc.). */
+#define AOM_EFLAG_FREEZE_INTERNAL_STATE (1 << 2)
 
 /*!\brief Encoder configuration structure
  *
diff --git a/av1/av1_cx_iface.c b/av1/av1_cx_iface.c
index e414f8a..252098c 100644
--- a/av1/av1_cx_iface.c
+++ b/av1/av1_cx_iface.c
@@ -3262,6 +3262,8 @@
   if (ppi->use_svc && ppi->cpi->svc.use_flexible_mode == 0 && flags == 0)
     av1_set_svc_fixed_mode(ppi->cpi);
 
+  ppi->b_freeze_internal_state = flags & AOM_EFLAG_FREEZE_INTERNAL_STATE;
+
   // Note(yunqing): While applying encoding flags, always start from enabling
   // all, and then modifying according to the flags. Previous frame's flags are
   // overwritten.
@@ -3495,6 +3497,9 @@
 
     // Call for LAP stage
     if (cpi_lap != NULL) {
+      if (cpi_lap->ppi->b_freeze_internal_state) {
+        av1_save_all_coding_context(cpi_lap);
+      }
       AV1_COMP_DATA cpi_lap_data = { 0 };
       cpi_lap_data.flush = !img;
       cpi_lap_data.timestamp_ratio = &ctx->timestamp_ratio;
@@ -3503,6 +3508,9 @@
         aom_internal_error_copy(&ppi->error, cpi_lap->common.error);
       }
       av1_post_encode_updates(cpi_lap, &cpi_lap_data);
+      if (cpi_lap->ppi->b_freeze_internal_state) {
+        restore_all_coding_context(cpi_lap);
+      }
     }
 
     // Recalculate the maximum number of frames that can be encoded in
@@ -3521,6 +3529,9 @@
       cpi->ref_idx_to_skip = INVALID_IDX;
       cpi->ref_refresh_index = INVALID_IDX;
       cpi->refresh_idx_available = false;
+      if (cpi->ppi->b_freeze_internal_state) {
+        av1_save_all_coding_context(cpi);
+      }
 
 #if CONFIG_FPMT_TEST
       simulate_parallel_frame =
@@ -3563,6 +3574,10 @@
       ppi->seq_params_locked = 1;
       av1_post_encode_updates(cpi, &cpi_data);
 
+      if (cpi->ppi->b_freeze_internal_state) {
+        restore_all_coding_context(cpi);
+      }
+
 #if CONFIG_ENTROPY_STATS
       if (ppi->cpi->oxcf.pass != 1 && !cpi->common.show_existing_frame)
         av1_accumulate_frame_counts(&ppi->aggregate_fc, &cpi->counts);
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index df01bae..423d0b0 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -5074,6 +5074,20 @@
   AV1_PRIMARY *const ppi = cpi->ppi;
   AV1_COMMON *const cm = &cpi->common;
 
+  if (ppi->b_freeze_internal_state) {
+    // Should not update encoder state, just necessary work to get the
+    // expected output and then return early.
+
+    // Note: frame_size == 0 indicates a dropped frame; PSNR is not calculated.
+    if (ppi->b_calculate_psnr && cpi_data->frame_size > 0) {
+      if (cm->show_existing_frame ||
+          (!is_stat_generation_stage(cpi) && cm->show_frame)) {
+        generate_psnr_packet(cpi);
+      }
+    }
+    return;
+  }
+
   update_gm_stats(cpi);
 
 #if !CONFIG_REALTIME_ONLY
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index ab19b1c..b31a7d9 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -2379,12 +2379,21 @@
   double max_scale;
 } WeberStats;
 
+/*!
+ * \brief This structure stores different types of frame indices.
+ */
+typedef struct {
+  int show_frame_count;
+} FRAME_INDEX_SET;
+
 typedef struct {
   struct loopfilter lf;
   CdefInfo cdef_info;
   YV12_BUFFER_CONFIG copy_buffer;
   RATE_CONTROL rc;
   MV_STATS mv_stats;
+  unsigned int frame_number;
+  FRAME_INDEX_SET frame_index_set;
 } CODING_CONTEXT;
 
 typedef struct {
@@ -2400,13 +2409,6 @@
   int subsampling_y;
 } FRAME_INFO;
 
-/*!
- * \brief This structure stores different types of frame indices.
- */
-typedef struct {
-  int show_frame_count;
-} FRAME_INDEX_SET;
-
 /*!\endcond */
 
 /*!
@@ -2888,6 +2890,13 @@
    * when --deltaq-mode=3.
    */
   AV1EncRowMultiThreadSync intra_row_mt_sync;
+
+  /*!
+   * If nonzero, the encoder should not update any internal state after
+   * completing the encode. E.g. no updates to reference buffers, CDF
+   * tables or RC state.
+   */
+  int b_freeze_internal_state;
 } AV1_PRIMARY;
 
 /*!
@@ -3767,9 +3776,9 @@
 
 void av1_initialize_enc(unsigned int usage, enum aom_rc_mode end_usage);
 
-struct AV1_COMP *av1_create_compressor(AV1_PRIMARY *ppi,
+struct AV1_COMP *av1_create_compressor(struct AV1_PRIMARY *ppi,
                                        const AV1EncoderConfig *oxcf,
-                                       BufferPool *const pool,
+                                       struct BufferPool *const pool,
                                        COMPRESSOR_STAGE stage,
                                        int lap_lag_in_frames);
 
diff --git a/av1/encoder/encoder_utils.c b/av1/encoder/encoder_utils.c
index 89ba731..07f4722 100644
--- a/av1/encoder/encoder_utils.c
+++ b/av1/encoder/encoder_utils.c
@@ -1544,6 +1544,8 @@
   cc->cdef_info = cm->cdef_info;
   cc->rc = cpi->rc;
   cc->mv_stats = cpi->ppi->mv_stats;
+  cc->frame_number = cpi->common.current_frame.frame_number;
+  cc->frame_index_set = cpi->frame_index_set;
 }
 
 void av1_save_all_coding_context(AV1_COMP *cpi) {
diff --git a/av1/encoder/encoder_utils.h b/av1/encoder/encoder_utils.h
index ca33eec..8691e48 100644
--- a/av1/encoder/encoder_utils.h
+++ b/av1/encoder/encoder_utils.h
@@ -889,6 +889,8 @@
   restore_cdef_coding_context(&cm->cdef_info, &cc->cdef_info);
   cpi->rc = cc->rc;
   cpi->ppi->mv_stats = cc->mv_stats;
+  cpi->common.current_frame.frame_number = cc->frame_number;
+  cpi->frame_index_set = cc->frame_index_set;
 }
 
 static inline int equal_dimensions_and_border(const YV12_BUFFER_CONFIG *a,
diff --git a/test/encode_api_test.cc b/test/encode_api_test.cc
index 94a24c6..874b421 100644
--- a/test/encode_api_test.cc
+++ b/test/encode_api_test.cc
@@ -1284,4 +1284,183 @@
 }
 #endif  // !CONFIG_REALTIME_ONLY
 
+TEST(EncodeAPI, FreezeInternalState) {
+  aom_codec_iface_t *iface = aom_codec_av1_cx();
+  aom_codec_enc_cfg_t cfg;
+  ASSERT_EQ(aom_codec_enc_config_default(iface, &cfg, kUsage), AOM_CODEC_OK);
+  cfg.g_w = 176;
+  cfg.g_h = 144;
+  cfg.rc_target_bitrate = 200;
+  cfg.g_lag_in_frames = 0;  // Needed for single frame updates
+
+  aom_codec_ctx_t enc;
+  ASSERT_EQ(aom_codec_enc_init(&enc, iface, &cfg, AOM_CODEC_USE_PSNR),
+            AOM_CODEC_OK);
+
+  aom_image_t *image = CreateGrayImage(AOM_IMG_FMT_I420, cfg.g_w, cfg.g_h);
+  ASSERT_NE(image, nullptr);
+
+  // Encode Frame A (Keyframe)
+  ASSERT_EQ(aom_codec_encode(&enc, image, /*pts=*/0, /*duration=*/1,
+                             /*flags=*/AOM_EFLAG_FORCE_KF),
+            AOM_CODEC_OK);
+
+  aom_codec_iter_t iter = nullptr;
+  while (aom_codec_get_cx_data(&enc, &iter) != nullptr) {
+    // Drain packets
+  }
+
+  std::vector<uint8_t> bitstream1;
+  bool psnr1 = false;
+
+  // Encode Frame B with freeze flag
+  ASSERT_EQ(aom_codec_encode(&enc, image, /*pts=*/1, /*duration=*/1,
+                             /*flags=*/AOM_EFLAG_FREEZE_INTERNAL_STATE |
+                                 AOM_EFLAG_CALCULATE_PSNR),
+            AOM_CODEC_OK);
+  iter = nullptr;
+  const aom_codec_cx_pkt_t *pkt;
+  while ((pkt = aom_codec_get_cx_data(&enc, &iter)) != nullptr) {
+    if (pkt->kind == AOM_CODEC_CX_FRAME_PKT) {
+      bitstream1.assign((uint8_t *)pkt->data.frame.buf,
+                        (uint8_t *)pkt->data.frame.buf + pkt->data.frame.sz);
+    } else if (pkt->kind == AOM_CODEC_PSNR_PKT) {
+      psnr1 = true;
+    }
+  }
+  EXPECT_TRUE(psnr1);
+  EXPECT_FALSE(bitstream1.empty());
+
+  std::vector<uint8_t> bitstream2;
+  bool psnr2 = false;
+
+  // Encode Frame B again without freeze flag
+  ASSERT_EQ(aom_codec_encode(&enc, image, /*pts=*/1, /*duration=*/1,
+                             /*flags=*/AOM_EFLAG_CALCULATE_PSNR),
+            AOM_CODEC_OK);
+  iter = nullptr;
+  while ((pkt = aom_codec_get_cx_data(&enc, &iter)) != nullptr) {
+    if (pkt->kind == AOM_CODEC_CX_FRAME_PKT) {
+      bitstream2.assign((uint8_t *)pkt->data.frame.buf,
+                        (uint8_t *)pkt->data.frame.buf + pkt->data.frame.sz);
+    } else if (pkt->kind == AOM_CODEC_PSNR_PKT) {
+      psnr2 = true;
+    }
+  }
+  EXPECT_TRUE(psnr2);
+  EXPECT_FALSE(bitstream2.empty());
+
+  // Bitstreams should be identical if the frozen encode left state untouched.
+  EXPECT_EQ(bitstream1, bitstream2);
+
+  aom_img_free(image);
+  ASSERT_EQ(aom_codec_destroy(&enc), AOM_CODEC_OK);
+}
+
+TEST(EncodeAPI, FreezeInternalStateSVC) {
+  aom_codec_iface_t *iface = aom_codec_av1_cx();
+  aom_codec_enc_cfg_t cfg;
+  ASSERT_EQ(aom_codec_enc_config_default(iface, &cfg, AOM_USAGE_REALTIME),
+            AOM_CODEC_OK);
+  cfg.g_w = 176;
+  cfg.g_h = 144;
+  cfg.rc_target_bitrate = 300;
+  cfg.g_lag_in_frames = 0;
+  cfg.rc_end_usage = AOM_CBR;
+
+  aom_codec_ctx_t enc;
+  ASSERT_EQ(aom_codec_enc_init(&enc, iface, &cfg, AOM_CODEC_USE_PSNR),
+            AOM_CODEC_OK);
+
+  ASSERT_EQ(aom_codec_control(&enc, AOME_SET_CPUUSED, 7), AOM_CODEC_OK);
+
+  aom_svc_params_t svc_params = {};
+  svc_params.number_spatial_layers = 2;
+  svc_params.number_temporal_layers = 1;
+  svc_params.max_quantizers[0] = 56;
+  svc_params.min_quantizers[0] = 10;
+  svc_params.max_quantizers[1] = 56;
+  svc_params.min_quantizers[1] = 10;
+  svc_params.scaling_factor_num[0] = 1;
+  svc_params.scaling_factor_den[0] = 2;
+  svc_params.scaling_factor_num[1] = 1;
+  svc_params.scaling_factor_den[1] = 1;
+  svc_params.layer_target_bitrate[0] = cfg.rc_target_bitrate * 2 / 3;
+  svc_params.layer_target_bitrate[1] = cfg.rc_target_bitrate;
+  svc_params.framerate_factor[0] = 1;
+  ASSERT_EQ(aom_codec_control(&enc, AV1E_SET_SVC_PARAMS, &svc_params),
+            AOM_CODEC_OK);
+
+  aom_image_t *image = CreateGrayImage(AOM_IMG_FMT_I420, cfg.g_w, cfg.g_h);
+  ASSERT_NE(image, nullptr);
+
+  aom_svc_layer_id_t layer_id = {};
+  aom_svc_ref_frame_config_t ref_frame_config = {};
+
+  // Encode SL0 - Keyframe
+  layer_id.spatial_layer_id = 0;
+  ASSERT_EQ(aom_codec_control(&enc, AV1E_SET_SVC_LAYER_ID, &layer_id),
+            AOM_CODEC_OK);
+  ASSERT_EQ(aom_codec_encode(&enc, image, /*pts=*/0, /*duration=*/1,
+                             /*flags=*/AOM_EFLAG_FORCE_KF),
+            AOM_CODEC_OK);
+
+  aom_codec_iter_t iter = nullptr;
+  while (aom_codec_get_cx_data(&enc, &iter) != nullptr) {
+  }  // Drain
+
+  // Encode SL1 - Delta Frame with Freeze
+  layer_id.spatial_layer_id = 1;
+  ASSERT_EQ(aom_codec_control(&enc, AV1E_SET_SVC_LAYER_ID, &layer_id),
+            AOM_CODEC_OK);
+  ref_frame_config.refresh[0] = 1;
+  ref_frame_config.reference[0] = 1;
+  ref_frame_config.ref_idx[0] = 0;
+  ASSERT_EQ(
+      aom_codec_control(&enc, AV1E_SET_SVC_REF_FRAME_CONFIG, &ref_frame_config),
+      AOM_CODEC_OK);
+
+  std::vector<uint8_t> bitstream1;
+  bool psnr1 = false;
+  ASSERT_EQ(aom_codec_encode(&enc, image, /*pts=*/0, /*duration=*/1,
+                             /*flags=*/AOM_EFLAG_FREEZE_INTERNAL_STATE |
+                                 AOM_EFLAG_CALCULATE_PSNR),
+            AOM_CODEC_OK);
+  iter = nullptr;
+  const aom_codec_cx_pkt_t *pkt;
+  while ((pkt = aom_codec_get_cx_data(&enc, &iter)) != nullptr) {
+    if (pkt->kind == AOM_CODEC_CX_FRAME_PKT) {
+      bitstream1.assign((uint8_t *)pkt->data.frame.buf,
+                        (uint8_t *)pkt->data.frame.buf + pkt->data.frame.sz);
+    } else if (pkt->kind == AOM_CODEC_PSNR_PKT) {
+      psnr1 = true;
+    }
+  }
+  EXPECT_TRUE(psnr1);
+  EXPECT_FALSE(bitstream1.empty());
+
+  // Encode SL1 - Delta Frame again without Freeze
+  std::vector<uint8_t> bitstream2;
+  bool psnr2 = false;
+  ASSERT_EQ(aom_codec_encode(&enc, image, /*pts=*/0, /*duration=*/1,
+                             /*flags=*/AOM_EFLAG_CALCULATE_PSNR),
+            AOM_CODEC_OK);
+  iter = nullptr;
+  while ((pkt = aom_codec_get_cx_data(&enc, &iter)) != nullptr) {
+    if (pkt->kind == AOM_CODEC_CX_FRAME_PKT) {
+      bitstream2.assign((uint8_t *)pkt->data.frame.buf,
+                        (uint8_t *)pkt->data.frame.buf + pkt->data.frame.sz);
+    } else if (pkt->kind == AOM_CODEC_PSNR_PKT) {
+      psnr2 = true;
+    }
+  }
+  EXPECT_TRUE(psnr2);
+  EXPECT_FALSE(bitstream2.empty());
+
+  EXPECT_EQ(bitstream1, bitstream2);
+
+  aom_img_free(image);
+  ASSERT_EQ(aom_codec_destroy(&enc), AOM_CODEC_OK);
+}
+
 }  // namespace