Implement TPL-model-based delta quantization for key frames

BD-rate change on top of delta quantization for the first-layer altref
(33 frames, speed 0):
lowres: -0.196% (avg psnr), -0.291% (ovr psnr)
midres: -0.109% (avg psnr), -0.245% (ovr psnr)

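Includes row_mt changes to resolve some test failures: the row-MT mutex
and per-thread tctx allocations in ethread.c are now gated on
cpi->oxcf.row_mt instead of cpi->row_mt.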

STATS_CHANGED

Change-Id: Ib69c828139d0dcf6c098fd381e6f5ed7271446d2
diff --git a/av1/encoder/encode_strategy.c b/av1/encoder/encode_strategy.c
index ee324dd..6378bc0 100644
--- a/av1/encoder/encode_strategy.c
+++ b/av1/encoder/encode_strategy.c
@@ -1100,20 +1100,6 @@
   // parameter should be used with caution.
   frame_params.speed = oxcf->speed;
 
-  if (!frame_params.show_existing_frame) {
-    cm->using_qmatrix = cpi->oxcf.using_qm;
-    cm->min_qmlevel = cpi->oxcf.qm_minlevel;
-    cm->max_qmlevel = cpi->oxcf.qm_maxlevel;
-    if (oxcf->pass == 2) {
-      if (cpi->twopass.gf_group.index == 1 && cpi->oxcf.enable_tpl_model) {
-        av1_configure_buffer_updates(cpi, &frame_params, frame_update_type, 0);
-        av1_set_frame_size(cpi, cm->width, cm->height);
-        av1_tpl_setup_stats(cpi, &frame_input);
-        assert(cpi->num_gf_group_show_frames == 1);
-      }
-    }
-  }
-
   // Work out some encoding parameters specific to the pass:
   if (oxcf->pass == 0) {
     if (cpi->oxcf.rc_mode == AOM_CBR) {
@@ -1178,6 +1164,29 @@
   memcpy(frame_params.remapped_ref_idx, cm->remapped_ref_idx,
          REF_FRAMES * sizeof(*cm->remapped_ref_idx));
 
+#if ENABLE_KF_TPL
+  if (oxcf->pass == 2 && frame_params.frame_type == KEY_FRAME &&
+      frame_params.show_frame) {
+    av1_configure_buffer_updates(cpi, &frame_params, frame_update_type, 0);
+    av1_set_frame_size(cpi, cm->width, cm->height);
+    av1_tpl_setup_stats(cpi, &frame_input, 1);
+  }
+#endif  // ENABLE_KF_TPL
+
+  if (!frame_params.show_existing_frame) {
+    cm->using_qmatrix = cpi->oxcf.using_qm;
+    cm->min_qmlevel = cpi->oxcf.qm_minlevel;
+    cm->max_qmlevel = cpi->oxcf.qm_maxlevel;
+    if (oxcf->pass == 2) {
+      if (cpi->twopass.gf_group.index == 1 && cpi->oxcf.enable_tpl_model) {
+        av1_configure_buffer_updates(cpi, &frame_params, frame_update_type, 0);
+        av1_set_frame_size(cpi, cm->width, cm->height);
+        av1_tpl_setup_stats(cpi, &frame_input, 0);
+        assert(cpi->num_gf_group_show_frames == 1);
+      }
+    }
+  }
+
   if (av1_encode(cpi, dest, &frame_input, &frame_params, &frame_results) !=
       AOM_CODEC_OK) {
     return AOM_CODEC_ERROR;
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index b817409..331ce0f 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4809,6 +4809,7 @@
     cpi->row_mt_sync_read_ptr = av1_row_mt_sync_read_dummy;
     cpi->row_mt_sync_write_ptr = av1_row_mt_sync_write_dummy;
     cpi->row_mt = 0;
+
     if (cpi->oxcf.row_mt && (cpi->oxcf.max_threads > 1) &&
         !cm->delta_q_info.delta_q_present_flag) {
       cpi->row_mt = 1;
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 2896519..3f7fce2 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -3556,52 +3556,49 @@
 
 static void process_tpl_stats_frame(AV1_COMP *cpi) {
   AV1_COMMON *const cm = &cpi->common;
-  if (cpi->twopass.gf_group.index &&
-      cpi->twopass.gf_group.index < MAX_LAG_BUFFERS &&
-      cpi->oxcf.enable_tpl_model && cpi->tpl_model_pass == 0) {
-    assert(IMPLIES(cpi->twopass.gf_group.size > 0,
-                   cpi->twopass.gf_group.index < cpi->twopass.gf_group.size));
-    const int tpl_idx =
-        cpi->twopass.gf_group.frame_disp_idx[cpi->twopass.gf_group.index];
-    TplDepFrame *tpl_frame = &cpi->tpl_stats[tpl_idx];
-    TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
 
-    if (tpl_frame->is_valid) {
-      int tpl_stride = tpl_frame->stride;
-      int64_t intra_cost_base = 0;
-      int64_t mc_dep_cost_base = 0;
-      int64_t mc_saved_base = 0;
-      int64_t mc_count_base = 0;
-      int row, col;
+  assert(IMPLIES(cpi->twopass.gf_group.size > 0,
+                 cpi->twopass.gf_group.index < cpi->twopass.gf_group.size));
+  const int tpl_idx =
+      cpi->twopass.gf_group.frame_disp_idx[cpi->twopass.gf_group.index];
+  TplDepFrame *tpl_frame = &cpi->tpl_stats[tpl_idx];
+  TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
 
-      for (row = 0; row < cm->mi_rows; ++row) {
-        for (col = 0; col < cm->mi_cols; ++col) {
-          TplDepStats *this_stats = &tpl_stats[row * tpl_stride + col];
-          intra_cost_base += this_stats->intra_cost;
-          mc_dep_cost_base += this_stats->intra_cost + this_stats->mc_flow;
-          mc_count_base += this_stats->mc_count;
-          mc_saved_base += this_stats->mc_saved;
-        }
+  if (tpl_frame->is_valid) {
+    int tpl_stride = tpl_frame->stride;
+    int64_t intra_cost_base = 0;
+    int64_t mc_dep_cost_base = 0;
+    int64_t mc_saved_base = 0;
+    int64_t mc_count_base = 0;
+    int row, col;
+
+    for (row = 0; row < cm->mi_rows; ++row) {
+      for (col = 0; col < cm->mi_cols; ++col) {
+        TplDepStats *this_stats = &tpl_stats[row * tpl_stride + col];
+        intra_cost_base += this_stats->intra_cost;
+        mc_dep_cost_base += this_stats->intra_cost + this_stats->mc_flow;
+        mc_count_base += this_stats->mc_count;
+        mc_saved_base += this_stats->mc_saved;
       }
+    }
 
-      if (mc_dep_cost_base == 0) {
-        tpl_frame->is_valid = 0;
-      } else {
-        aom_clear_system_state();
-        cpi->rd.r0 = (double)intra_cost_base / mc_dep_cost_base;
-        if (is_frame_arf_and_tpl_eligible(cpi)) {
-          cpi->rd.arf_r0 = cpi->rd.r0;
-          const int gfu_boost = get_gfu_boost_from_r0(cpi->rd.arf_r0);
-          // printf("old boost %d new boost %d\n", cpi->rc.gfu_boost,
-          // gfu_boost);
-          cpi->rc.gfu_boost = (cpi->rc.gfu_boost + gfu_boost) / 2;
-        }
-        cpi->rd.mc_count_base =
-            (double)mc_count_base / (cm->mi_rows * cm->mi_cols);
-        cpi->rd.mc_saved_base =
-            (double)mc_saved_base / (cm->mi_rows * cm->mi_cols);
-        aom_clear_system_state();
+    if (mc_dep_cost_base == 0) {
+      tpl_frame->is_valid = 0;
+    } else {
+      aom_clear_system_state();
+      cpi->rd.r0 = (double)intra_cost_base / mc_dep_cost_base;
+      if (is_frame_arf_and_tpl_eligible(cpi)) {
+        cpi->rd.arf_r0 = cpi->rd.r0;
+        const int gfu_boost = get_gfu_boost_from_r0(cpi->rd.arf_r0);
+        // printf("old boost %d new boost %d\n", cpi->rc.gfu_boost,
+        // gfu_boost);
+        cpi->rc.gfu_boost = (cpi->rc.gfu_boost + gfu_boost) / 2;
       }
+      cpi->rd.mc_count_base =
+          (double)mc_count_base / (cm->mi_rows * cm->mi_cols);
+      cpi->rd.mc_saved_base =
+          (double)mc_saved_base / (cm->mi_rows * cm->mi_cols);
+      aom_clear_system_state();
     }
   }
 }
@@ -3614,7 +3611,9 @@
   // Setup variables that depend on the dimensions of the frame.
   av1_set_speed_features_framesize_dependent(cpi, cpi->speed);
 
-  if (is_frame_tpl_eligible(cpi)) process_tpl_stats_frame(cpi);
+  if (cpi->oxcf.enable_tpl_model && cpi->tpl_model_pass == 0 &&
+      is_frame_tpl_eligible(cpi))
+    process_tpl_stats_frame(cpi);
 
   // Decide q and q bounds.
   *q = av1_rc_pick_q_and_bounds(cpi, cm->width, cm->height, bottom_index,
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 1d235b7..7f4165c 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -1324,17 +1324,15 @@
 // field.
 aom_fixed_buf_t *av1_get_global_headers(AV1_COMP *cpi);
 
+#define ENABLE_KF_TPL 1
 #define MAX_PYR_LEVEL_FROMTOP_DELTAQ 0
-static INLINE int is_frame_tpl_eligible(AV1_COMP *const cpi) {
-  const int max_pyr_level_fromtop_deltaq = MAX_PYR_LEVEL_FROMTOP_DELTAQ;
-  const int pyr_lev_from_top =
-      cpi->twopass.gf_group.pyramid_height -
-      cpi->twopass.gf_group.pyramid_level[cpi->twopass.gf_group.index];
-  if (pyr_lev_from_top > max_pyr_level_fromtop_deltaq ||
-      cpi->twopass.gf_group.pyramid_height <= max_pyr_level_fromtop_deltaq + 1)
-    return 0;
-  else
+
+static INLINE int is_frame_kf_and_tpl_eligible(AV1_COMP *const cpi) {
+  AV1_COMMON *cm = &cpi->common;
+  if (cm->current_frame.frame_type == KEY_FRAME && cm->show_frame)
     return 1;
+  else
+    return 0;
 }
 
 static INLINE int is_frame_arf_and_tpl_eligible(AV1_COMP *const cpi) {
@@ -1349,6 +1347,15 @@
     return 1;
 }
 
+static INLINE int is_frame_tpl_eligible(AV1_COMP *const cpi) {
+#if ENABLE_KF_TPL
+  return is_frame_kf_and_tpl_eligible(cpi) ||
+         is_frame_arf_and_tpl_eligible(cpi);
+#else
+  return is_frame_arf_and_tpl_eligible(cpi);
+#endif  // ENABLE_KF_TPL
+}
+
 #if CONFIG_COLLECT_PARTITION_STATS == 2
 static INLINE void av1_print_partition_stats(PartitionStats *part_stats) {
   FILE *f = fopen("partition_stats.csv", "w");
diff --git a/av1/encoder/ethread.c b/av1/encoder/ethread.c
index 259746b..1c170a8 100644
--- a/av1/encoder/ethread.c
+++ b/av1/encoder/ethread.c
@@ -385,7 +385,7 @@
                   aom_calloc(num_workers, sizeof(*cpi->tile_thr_data)));
 
 #if CONFIG_MULTITHREAD
-  if (cpi->row_mt == 1) {
+  if (cpi->oxcf.row_mt == 1) {
     if (cpi->row_mt_mutex_ == NULL) {
       CHECK_MEM_ERROR(cm, cpi->row_mt_mutex_,
                       aom_malloc(sizeof(*(cpi->row_mt_mutex_))));
@@ -473,7 +473,7 @@
       // Main thread acts as a worker and uses the thread data in cpi.
       thread_data->td = &cpi->td;
     }
-    if (cpi->row_mt == 1)
+    if (cpi->oxcf.row_mt == 1)
       CHECK_MEM_ERROR(
           cm, thread_data->td->tctx,
           (FRAME_CONTEXT *)aom_memalign(16, sizeof(*thread_data->td->tctx)));
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 9d8a215..2417ec8 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -253,14 +253,20 @@
     }
   }
   best_intra_cost = AOMMAX(best_intra_cost, 1);
-  best_inter_cost = AOMMIN(best_intra_cost, (int64_t)best_inter_cost_weighted);
+  if (frame_idx == 0)
+    best_inter_cost = 0;
+  else
+    best_inter_cost =
+        AOMMIN(best_intra_cost, (int64_t)best_inter_cost_weighted);
   tpl_stats->inter_cost = best_inter_cost << TPL_DEP_COST_SCALE_LOG2;
   tpl_stats->intra_cost = best_intra_cost << TPL_DEP_COST_SCALE_LOG2;
 
-  const int idx = gf_group->ref_frame_gop_idx[frame_idx][best_rf_idx];
-  tpl_stats->ref_frame_index = idx;
-  tpl_stats->ref_disp_frame_index = cpi->twopass.gf_group.frame_disp_idx[idx];
-  tpl_stats->mv.as_int = best_mv.as_int;
+  if (frame_idx) {
+    const int idx = gf_group->ref_frame_gop_idx[frame_idx][best_rf_idx];
+    tpl_stats->ref_frame_index = idx;
+    tpl_stats->ref_disp_frame_index = cpi->twopass.gf_group.frame_disp_idx[idx];
+    tpl_stats->mv.as_int = best_mv.as_int;
+  }
 }
 
 static int round_floor(int ref_pos, int bsize_pix) {
@@ -466,6 +472,8 @@
       &sf, this_frame->y_crop_width, this_frame->y_crop_height,
       this_frame->y_crop_width, this_frame->y_crop_height);
 
+  xd->cur_buf = this_frame;
+
   if (is_cur_buf_hbd(xd))
     predictor = CONVERT_TO_BYTEPTR(predictor16);
   else
@@ -490,7 +498,7 @@
 
   xd->mi = cm->mi_grid_visible;
   xd->mi[0] = cm->mi;
-  xd->cur_buf = this_frame;
+  xd->block_ref_scale_factors[0] = &sf;
 
   const int base_qindex = gf_group->q_val[frame_idx];
   // Get rd multiplier set up.
@@ -528,8 +536,9 @@
       tpl_model_store(tpl_frame->tpl_stats_ptr, mi_row, mi_col, bsize,
                       tpl_frame->stride, &tpl_stats);
 
-      tpl_model_update(cpi->tpl_stats, tpl_frame->tpl_stats_ptr, mi_row, mi_col,
-                       bsize);
+      if (frame_idx)
+        tpl_model_update(cpi->tpl_stats, tpl_frame->tpl_stats_ptr, mi_row,
+                         mi_col, bsize);
     }
   }
 }
@@ -539,7 +548,8 @@
 static void init_gop_frames_for_tpl(AV1_COMP *cpi,
                                     YV12_BUFFER_CONFIG **gf_picture,
                                     GF_GROUP *gf_group, int *tpl_group_frames,
-                                    const EncodeFrameInput *const frame_input) {
+                                    const EncodeFrameInput *const frame_input,
+                                    int is_for_kf) {
   AV1_COMMON *cm = &cpi->common;
   const SequenceHeader *const seq_params = &cm->seq_params;
   int frame_idx = 0;
@@ -564,16 +574,20 @@
 
   *tpl_group_frames = 0;
 
-  // Initialize Golden reference frame.
-  RefCntBuffer *ref_buf = get_ref_frame_buf(cm, GOLDEN_FRAME);
-  gf_picture[0] = &ref_buf->buf;
-  ++*tpl_group_frames;
+  if (!is_for_kf) {
+    // Initialize Golden reference frame.
+    RefCntBuffer *ref_buf = get_ref_frame_buf(cm, GOLDEN_FRAME);
+    gf_picture[0] = &ref_buf->buf;
+    ++*tpl_group_frames;
+  }
+
+  int start_idx = !is_for_kf;
 
   // Initialize frames in the GF group
-  for (frame_idx = 1;
+  for (frame_idx = start_idx;
        frame_idx <= AOMMIN(gf_group->size, MAX_LENGTH_TPL_FRAME_STATS - 1);
        ++frame_idx) {
-    if (frame_idx == 1) {
+    if (frame_idx == start_idx) {
       gf_picture[frame_idx] = frame_input->source;
       frame_disp_idx = gf_group->frame_disp_idx[frame_idx];
     } else {
@@ -599,6 +613,8 @@
     ++*tpl_group_frames;
   }
 
+  if (is_for_kf) return;
+
   if (frame_idx < MAX_LENGTH_TPL_FRAME_STATS) {
     ++frame_disp_idx;
     int extend_frame_count = 0;
@@ -663,8 +679,7 @@
 }
 
 static void init_tpl_stats(AV1_COMP *cpi) {
-  int frame_idx;
-  for (frame_idx = 0; frame_idx < MAX_LENGTH_TPL_FRAME_STATS; ++frame_idx) {
+  for (int frame_idx = 0; frame_idx < MAX_LENGTH_TPL_FRAME_STATS; ++frame_idx) {
     TplDepFrame *tpl_frame = &cpi->tpl_stats[frame_idx];
     memset(tpl_frame->tpl_stats_ptr, 0,
            tpl_frame->height * tpl_frame->width *
@@ -674,19 +689,20 @@
 }
 
 void av1_tpl_setup_stats(AV1_COMP *cpi,
-                         const EncodeFrameInput *const frame_input) {
+                         const EncodeFrameInput *const frame_input,
+                         int is_for_kf) {
   YV12_BUFFER_CONFIG *gf_picture[MAX_LENGTH_TPL_FRAME_STATS];
   GF_GROUP *gf_group = &cpi->twopass.gf_group;
-  int frame_idx;
 
   init_gop_frames_for_tpl(cpi, gf_picture, gf_group, &cpi->tpl_gf_group_frames,
-                          frame_input);
+                          frame_input, is_for_kf);
 
   init_tpl_stats(cpi);
 
   if (cpi->oxcf.enable_tpl_model == 1) {
     // Backward propagation from tpl_group_frames to 1.
-    for (frame_idx = cpi->tpl_gf_group_frames - 1; frame_idx > 0; --frame_idx) {
+    for (int frame_idx = cpi->tpl_gf_group_frames - 1; frame_idx >= !is_for_kf;
+         --frame_idx) {
       if (gf_group->update_type[frame_idx] == OVERLAY_UPDATE ||
           gf_group->update_type[frame_idx] == INTNL_OVERLAY_UPDATE)
         continue;
@@ -733,6 +749,7 @@
   xd->left_mbmi = NULL;
   xd->mi[0]->sb_type = bsize;
   xd->mi[0]->motion_mode = SIMPLE_TRANSLATION;
+  xd->block_ref_scale_factors[0] = &sf;
 
   for (int mi_row = 0; mi_row < cm->mi_rows; mi_row += mi_height) {
     // Motion estimation row boundary
diff --git a/av1/encoder/tpl_model.h b/av1/encoder/tpl_model.h
index 4732d1c..36be9ba 100644
--- a/av1/encoder/tpl_model.h
+++ b/av1/encoder/tpl_model.h
@@ -17,7 +17,8 @@
 #endif
 
 void av1_tpl_setup_stats(AV1_COMP *cpi,
-                         const EncodeFrameInput *const frame_input);
+                         const EncodeFrameInput *const frame_input,
+                         int is_for_kf);
 
 void av1_tpl_setup_forward_stats(AV1_COMP *cpi);