Hierarchical lambda scaling within a SB using TPL

STATS_CHANGED:
150frames, SP1, Test
RC type: Q
              AVG_PSNR   OVR_PSNR   SSIM
lowres           -0.27      -0.30  -0.73
midres           -0.22      -0.22  -0.56
Hdres (95% done) -0.18      -0.20  -0.19

RC type: VBR
              AVG_PSNR   OVR_PSNR   SSIM
lowres           -0.28      -0.30  -0.69
midres           -0.26      -0.27  -0.49
Hdres (95% done) -0.45      -0.45  -0.44

Change-Id: I1d5a0a4cec4e0226122a0587bf5f95ee10f5b0ef
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 4e587a7..fd11726 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -65,6 +65,7 @@
 #include "av1/encoder/reconinter_enc.h"
 #include "av1/encoder/segmentation.h"
 #include "av1/encoder/tokenize.h"
+#include "av1/encoder/tpl_model.h"
 #include "av1/encoder/var_based_part.h"
 #include "av1/encoder/tpl_model.h"
 
@@ -235,6 +236,7 @@
 
   assert(cpi->oxcf.tuning == AOM_TUNE_SSIM);
 
+  aom_clear_system_state();
   for (row = mi_row / num_mi_w;
        row < num_rows && row < mi_row / num_mi_w + num_brows; ++row) {
     for (col = mi_col / num_mi_h;
@@ -246,12 +248,64 @@
   }
   geom_mean_of_scale = exp(geom_mean_of_scale / num_of_mi);
 
-  *rdmult = (int)((double)(*rdmult) * geom_mean_of_scale);
+  *rdmult = (int)((double)(*rdmult) * geom_mean_of_scale + 0.5);
   *rdmult = AOMMAX(*rdmult, 0);
   set_error_per_bit(x, *rdmult);
   aom_clear_system_state();
 }
 
+static int get_hier_tpl_rdmult(const AV1_COMP *const cpi, MACROBLOCK *const x,
+                               const BLOCK_SIZE bsize, const int mi_row,
+                               const int mi_col, int orig_rdmult) {
+  const AV1_COMMON *const cm = &cpi->common;
+  assert(IMPLIES(cpi->gf_group.size > 0,
+                 cpi->gf_group.index < cpi->gf_group.size));
+  const int tpl_idx = cpi->gf_group.index;
+  const TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_idx];
+  MACROBLOCKD *const xd = &x->e_mbd;
+  const int deltaq_rdmult = set_deltaq_rdmult(cpi, xd);
+  if (cpi->tpl_model_pass == 1) {
+    assert(cpi->oxcf.enable_tpl_model == 2);
+    return deltaq_rdmult;
+  }
+  if (tpl_frame->is_valid == 0) return deltaq_rdmult;
+  if (!is_frame_tpl_eligible((AV1_COMP *)cpi)) return deltaq_rdmult;
+  if (tpl_idx >= MAX_LAG_BUFFERS) return deltaq_rdmult;
+  if (cpi->oxcf.superres_mode != SUPERRES_NONE) return deltaq_rdmult;
+
+  const int bsize_base = BLOCK_16X16;
+  const int num_mi_w = mi_size_wide[bsize_base];
+  const int num_mi_h = mi_size_high[bsize_base];
+  const int num_cols = (cm->mi_cols + num_mi_w - 1) / num_mi_w;
+  const int num_rows = (cm->mi_rows + num_mi_h - 1) / num_mi_h;
+  const int num_bcols = (mi_size_wide[bsize] + num_mi_w - 1) / num_mi_w;
+  const int num_brows = (mi_size_high[bsize] + num_mi_h - 1) / num_mi_h;
+  int row, col;
+  double base_block_count = 0.0;
+  double geom_mean_of_scale = 0.0;
+  aom_clear_system_state();
+  for (row = mi_row / num_mi_w;
+       row < num_rows && row < mi_row / num_mi_w + num_brows; ++row) {
+    for (col = mi_col / num_mi_h;
+         col < num_cols && col < mi_col / num_mi_h + num_bcols; ++col) {
+      const int index = row * num_cols + col;
+      geom_mean_of_scale += log(cpi->tpl_sb_rdmult_scaling_factors[index]);
+      base_block_count += 1.0;
+    }
+  }
+  geom_mean_of_scale = exp(geom_mean_of_scale / base_block_count);
+  int rdmult = (int)((double)orig_rdmult * geom_mean_of_scale + 0.5);
+  rdmult = AOMMAX(rdmult, 0);
+  set_error_per_bit(x, rdmult);
+  aom_clear_system_state();
+  if (bsize == cm->seq_params.sb_size) {
+    const int rdmult_sb = set_deltaq_rdmult(cpi, xd);
+    assert(rdmult_sb == rdmult);
+    (void)rdmult_sb;
+  }
+  return rdmult;
+}
+
 static int set_segment_rdmult(const AV1_COMP *const cpi, MACROBLOCK *const x,
                               int8_t segment_id) {
   const AV1_COMMON *const cm = &cpi->common;
@@ -287,8 +341,7 @@
 
   const AV1_COMMON *const cm = &cpi->common;
   if (cm->delta_q_info.delta_q_present_flag) {
-    MACROBLOCKD *const xd = &x->e_mbd;
-    x->rdmult = set_deltaq_rdmult(cpi, xd);
+    x->rdmult = get_hier_tpl_rdmult(cpi, x, bsize, mi_row, mi_col, x->rdmult);
   }
 
   if (cpi->oxcf.tuning == AOM_TUNE_SSIM) {
@@ -4030,8 +4083,10 @@
     xd->cur_frame_force_integer_mv = cm->cur_frame_force_integer_mv;
 
     x->sb_energy_level = 0;
-    if (cm->delta_q_info.delta_q_present_flag)
+    if (cm->delta_q_info.delta_q_present_flag) {
       setup_delta_q(cpi, td, x, tile_info, mi_row, mi_col, num_planes);
+      av1_tpl_rdmult_setup_sb(cpi, x, sb_size, mi_row, mi_col);
+    }
 
     td->mb.cb_coef_buff = av1_get_cb_coeff_buffer(cpi, mi_row, mi_col);
 
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 1cca1a5..878e42e 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -555,6 +555,12 @@
   aom_free(cpi->ssim_rdmult_scaling_factors);
   cpi->ssim_rdmult_scaling_factors = NULL;
 
+  aom_free(cpi->tpl_rdmult_scaling_factors);
+  cpi->tpl_rdmult_scaling_factors = NULL;
+
+  aom_free(cpi->tpl_sb_rdmult_scaling_factors);
+  cpi->tpl_sb_rdmult_scaling_factors = NULL;
+
   aom_free(cpi->td.mb.above_pred_buf);
   cpi->td.mb.above_pred_buf = NULL;
 
@@ -2838,6 +2844,20 @@
     const int h = mi_size_high[bsize];
     const int num_cols = (cm->mi_cols + w - 1) / w;
     const int num_rows = (cm->mi_rows + h - 1) / h;
+    CHECK_MEM_ERROR(cm, cpi->tpl_rdmult_scaling_factors,
+                    aom_calloc(num_rows * num_cols,
+                               sizeof(*cpi->tpl_rdmult_scaling_factors)));
+    CHECK_MEM_ERROR(cm, cpi->tpl_sb_rdmult_scaling_factors,
+                    aom_calloc(num_rows * num_cols,
+                               sizeof(*cpi->tpl_sb_rdmult_scaling_factors)));
+  }
+
+  {
+    const int bsize = BLOCK_16X16;
+    const int w = mi_size_wide[bsize];
+    const int h = mi_size_high[bsize];
+    const int num_cols = (cm->mi_cols + w - 1) / w;
+    const int num_rows = (cm->mi_rows + h - 1) / h;
     CHECK_MEM_ERROR(cm, cpi->ssim_rdmult_scaling_factors,
                     aom_calloc(num_rows * num_cols,
                                sizeof(*cpi->ssim_rdmult_scaling_factors)));
@@ -3779,8 +3799,10 @@
   av1_set_speed_features_framesize_dependent(cpi, cpi->speed);
 
   if (cpi->oxcf.enable_tpl_model && cpi->tpl_model_pass == 0 &&
-      is_frame_tpl_eligible(cpi))
+      is_frame_tpl_eligible(cpi)) {
     process_tpl_stats_frame(cpi);
+    av1_tpl_rdmult_setup(cpi);
+  }
 
   // Decide q and q bounds.
   *q = av1_rc_pick_q_and_bounds(cpi, cm->width, cm->height, cpi->gf_group.index,
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index d511c51..53cc6a2 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -1058,6 +1058,9 @@
   int8_t nearest_past_ref;
   int8_t nearest_future_ref;
 
+  // TODO(sdeng): consider merge the following arrays.
+  double *tpl_rdmult_scaling_factors;
+  double *tpl_sb_rdmult_scaling_factors;
   double *ssim_rdmult_scaling_factors;
 
   int use_svc;
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 9756328..1031aaa 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -16,6 +16,7 @@
 #include "config/aom_dsp_rtcd.h"
 
 #include "aom/aom_codec.h"
+#include "aom_ports/system_state.h"
 
 #include "av1/common/enums.h"
 #include "av1/common/onyxc_int.h"
@@ -1007,3 +1008,112 @@
     }
   }
 }
+
+void av1_tpl_rdmult_setup(AV1_COMP *cpi) {
+  const AV1_COMMON *const cm = &cpi->common;
+  const GF_GROUP *const gf_group = &cpi->gf_group;
+  const int tpl_idx = gf_group->index;
+
+  assert(IMPLIES(gf_group->size > 0, tpl_idx < gf_group->size));
+
+  const TplDepFrame *const tpl_frame = &cpi->tpl_frame[tpl_idx];
+  if (!tpl_frame->is_valid) return;
+  if (cpi->oxcf.superres_mode != SUPERRES_NONE) return;
+  const TplDepStats *const tpl_stats = tpl_frame->tpl_stats_ptr;
+  const int tpl_stride = tpl_frame->stride;
+
+  const int block_size = BLOCK_16X16;
+  const int num_mi_w = mi_size_wide[block_size];
+  const int num_mi_h = mi_size_high[block_size];
+  const int num_cols = (cm->mi_cols + num_mi_w - 1) / num_mi_w;
+  const int num_rows = (cm->mi_rows + num_mi_h - 1) / num_mi_h;
+  const double c = 1.2;
+  const int step = 1 << cpi->tpl_stats_block_mis_log2;
+
+  aom_clear_system_state();
+
+  // Loop through each 'block_size' X 'block_size' block.
+  for (int row = 0; row < num_rows; row++) {
+    for (int col = 0; col < num_cols; col++) {
+      double intra_cost = 0.0, mc_dep_cost = 0.0;
+      // Loop through each mi block.
+      for (int mi_row = row * num_mi_h; mi_row < (row + 1) * num_mi_h;
+           mi_row += step) {
+        for (int mi_col = col * num_mi_w; mi_col < (col + 1) * num_mi_w;
+             mi_col += step) {
+          if (mi_row >= cm->mi_rows || mi_col >= cm->mi_cols) continue;
+          const TplDepStats *this_stats =
+              &tpl_stats[av1_tpl_ptr_pos(cpi, mi_row, mi_col, tpl_stride)];
+          intra_cost += (double)this_stats->intra_cost;
+          mc_dep_cost += (double)this_stats->intra_cost + this_stats->mc_flow;
+        }
+      }
+      const double rk = intra_cost / mc_dep_cost;
+      const int index = row * num_cols + col;
+      cpi->tpl_rdmult_scaling_factors[index] = rk / cpi->rd.r0 + c;
+    }
+  }
+  aom_clear_system_state();
+}
+
+void av1_tpl_rdmult_setup_sb(AV1_COMP *cpi, MACROBLOCK *const x,
+                             BLOCK_SIZE sb_size, int mi_row, int mi_col) {
+  AV1_COMMON *const cm = &cpi->common;
+  assert(IMPLIES(cpi->gf_group.size > 0,
+                 cpi->gf_group.index < cpi->gf_group.size));
+  const int tpl_idx = cpi->gf_group.index;
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_idx];
+
+  if (cpi->tpl_model_pass == 1) {
+    assert(cpi->oxcf.enable_tpl_model == 2);
+    return;
+  }
+  if (tpl_frame->is_valid == 0) return;
+  if (!is_frame_tpl_eligible(cpi)) return;
+  if (tpl_idx >= MAX_LAG_BUFFERS) return;
+  if (cpi->oxcf.superres_mode != SUPERRES_NONE) return;
+
+  const int bsize_base = BLOCK_16X16;
+  const int num_mi_w = mi_size_wide[bsize_base];
+  const int num_mi_h = mi_size_high[bsize_base];
+  const int num_cols = (cm->mi_cols + num_mi_w - 1) / num_mi_w;
+  const int num_rows = (cm->mi_rows + num_mi_h - 1) / num_mi_h;
+  const int num_bcols = (mi_size_wide[sb_size] + num_mi_w - 1) / num_mi_w;
+  const int num_brows = (mi_size_high[sb_size] + num_mi_h - 1) / num_mi_h;
+  int row, col;
+
+  double base_block_count = 0.0;
+  double log_sum = 0.0;
+
+  aom_clear_system_state();
+  for (row = mi_row / num_mi_w;
+       row < num_rows && row < mi_row / num_mi_w + num_brows; ++row) {
+    for (col = mi_col / num_mi_h;
+         col < num_cols && col < mi_col / num_mi_h + num_bcols; ++col) {
+      const int index = row * num_cols + col;
+      log_sum += log(cpi->tpl_rdmult_scaling_factors[index]);
+      base_block_count += 1.0;
+    }
+  }
+
+  MACROBLOCKD *const xd = &x->e_mbd;
+  const int orig_rdmult =
+      av1_compute_rd_mult(cpi, cm->base_qindex + cm->y_dc_delta_q);
+  const int new_rdmult = av1_compute_rd_mult(
+      cpi, cm->base_qindex + xd->delta_qindex + cm->y_dc_delta_q);
+  const double scaling_factor = (double)new_rdmult / (double)orig_rdmult;
+
+  double scale_adj = log(scaling_factor) - log_sum / base_block_count;
+  scale_adj = exp(scale_adj);
+
+  for (row = mi_row / num_mi_w;
+       row < num_rows && row < mi_row / num_mi_w + num_brows; ++row) {
+    for (col = mi_col / num_mi_h;
+         col < num_cols && col < mi_col / num_mi_h + num_bcols; ++col) {
+      const int index = row * num_cols + col;
+      cpi->tpl_sb_rdmult_scaling_factors[index] =
+          scale_adj * cpi->tpl_rdmult_scaling_factors[index];
+    }
+  }
+  aom_clear_system_state();
+}
diff --git a/av1/encoder/tpl_model.h b/av1/encoder/tpl_model.h
index e3b1e68..ba74543 100644
--- a/av1/encoder/tpl_model.h
+++ b/av1/encoder/tpl_model.h
@@ -24,6 +24,11 @@
 
 int av1_tpl_ptr_pos(AV1_COMP *cpi, int mi_row, int mi_col, int stride);
 
+void av1_tpl_rdmult_setup(AV1_COMP *cpi);
+
+void av1_tpl_rdmult_setup_sb(AV1_COMP *cpi, MACROBLOCK *const x,
+                             BLOCK_SIZE sb_size, int mi_row, int mi_col);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif