Rework tpl model kernel

Make the model propagation kernel account for the rate-distortion
cost when propagating through the motion-compensated dependency.

STATS_CHANGED

Change-Id: I36d27e280acfe85ccdcd9f44c48512e7ddcd69a4
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 58422c4..c6e2274 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -169,6 +169,8 @@
   int64_t srcrf_dist;
   int64_t recrf_dist;
   int64_t mc_dep_delta;
+  int64_t src_rdcost;
+  int64_t rec_rdcost;
   int_mv mv;
   int ref_frame_index;
   double quant_ratio;
diff --git a/av1/encoder/ratectrl.h b/av1/encoder/ratectrl.h
index 83a72be..0c92af6 100644
--- a/av1/encoder/ratectrl.h
+++ b/av1/encoder/ratectrl.h
@@ -156,6 +156,7 @@
   // Q index used for ALT frame
   int arf_q;
   int active_worst_quality;
+  int base_layer_qp;
 } RATE_CONTROL;
 
 struct AV1_COMP;
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index cfed647..e7d2269 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -66,6 +66,34 @@
   av1_fwd_txfm(src_diff, coeff, bw, &txfm_param);
 }
 
+// Estimates the rate of the first |eob| quantized coefficients of a
+// transform block, visited in the default scan order for |tx_size|.
+//
+// The model charges a base cost of 1 plus, for every scanned coefficient,
+// floor(log2(|level| + 1)) + 1. The total is returned scaled up by
+// AV1_PROB_COST_SHIFT bits.
+static int rate_estimator(tran_low_t *qcoeff, int eob, TX_SIZE tx_size) {
+  const SCAN_ORDER *const scan_order = &av1_default_scan_orders[tx_size];
+
+  // eob may not exceed the pel count looked up for this transform size.
+  assert((1 << num_pels_log2_lookup[txsize_to_bsize[tx_size]]) >= eob);
+
+  int rate_cost = 1;
+
+  for (int idx = 0; idx < eob; ++idx) {
+    const int abs_level = abs(qcoeff[scan_order->scan[idx]]);
+    // Integer floor(log2(abs_level + 1)). Avoids log()/log(2.0), whose
+    // truncated quotient can be one too small at exact powers of two on
+    // some math libraries, making the stats platform dependent.
+    unsigned int magnitude = (unsigned int)abs_level + 1u;
+    int log2_floor = 0;
+    while (magnitude >>= 1) ++log2_floor;
+    rate_cost += log2_floor + 1;
+  }
+
+  return (rate_cost << AV1_PROB_COST_SHIFT);
+}
+
 static uint32_t motion_estimation(AV1_COMP *cpi, MACROBLOCK *x,
                                   uint8_t *cur_frame_buf,
                                   uint8_t *ref_frame_buf, int stride,
@@ -127,7 +142,8 @@
     tran_low_t *dqcoeff, int mi_row, int mi_col, BLOCK_SIZE bsize,
     TX_SIZE tx_size, const YV12_BUFFER_CONFIG *ref_frame[],
     const YV12_BUFFER_CONFIG *src_ref_frame[], uint8_t *predictor,
-    int64_t *recon_error, int64_t *sse, TplDepStats *tpl_stats) {
+    int base_rdmult, int64_t *recon_error, int64_t *sse,
+    TplDepStats *tpl_stats) {
   AV1_COMMON *cm = &cpi->common;
   const GF_GROUP *gf_group = &cpi->gf_group;
 
@@ -145,6 +161,7 @@
 
   int64_t best_intra_cost = INT64_MAX;
   int64_t intra_cost;
+  int64_t best_rdcost = 0;
   PREDICTION_MODE mode;
   PREDICTION_MODE best_mode = DC_PRED;
 
@@ -268,6 +285,8 @@
       best_mv.as_int = x->best_mv.as_int;
       get_quantize_error(x, 0, coeff, qcoeff, dqcoeff, tx_size, &eob,
                          recon_error, sse);
+      int rate_cost = rate_estimator(qcoeff, eob, tx_size);
+      best_rdcost = RDCOST(base_rdmult, rate_cost, *recon_error);
     }
   }
   best_intra_cost = AOMMAX(best_intra_cost, 1);
@@ -279,7 +298,9 @@
   tpl_stats->inter_cost = best_inter_cost << TPL_DEP_COST_SCALE_LOG2;
   tpl_stats->intra_cost = best_intra_cost << TPL_DEP_COST_SCALE_LOG2;
 
-  tpl_stats->srcrf_dist = *recon_error << TPL_DEP_COST_SCALE_LOG2;
+  tpl_stats->srcrf_dist = *recon_error
+                          << (TPL_DEP_COST_SCALE_LOG2 + RDDIV_BITS);
+  tpl_stats->src_rdcost = best_rdcost << TPL_DEP_COST_SCALE_LOG2;
 
   // Final encode
   if (is_inter_mode(best_mode)) {
@@ -324,14 +345,23 @@
   get_quantize_error(x, 0, coeff, qcoeff, dqcoeff, tx_size, &eob, recon_error,
                      sse);
 
+  int rate_cost = rate_estimator(qcoeff, eob, tx_size);
+
   av1_inverse_transform_block(xd, dqcoeff, 0, DCT_DCT, tx_size, dst_buffer,
                               dst_buffer_stride, eob, 0);
 
-  tpl_stats->recrf_dist = *recon_error << TPL_DEP_COST_SCALE_LOG2;
-  if (!is_inter_mode(best_mode))
-    tpl_stats->srcrf_dist = *recon_error << TPL_DEP_COST_SCALE_LOG2;
-
+  tpl_stats->recrf_dist = *recon_error
+                          << (TPL_DEP_COST_SCALE_LOG2 + RDDIV_BITS);
+  tpl_stats->rec_rdcost = RDCOST(base_rdmult, rate_cost, *recon_error)
+                          << TPL_DEP_COST_SCALE_LOG2;
+  if (!is_inter_mode(best_mode)) {
+    tpl_stats->srcrf_dist = *recon_error
+                            << (TPL_DEP_COST_SCALE_LOG2 + RDDIV_BITS);
+    tpl_stats->src_rdcost = RDCOST(base_rdmult, rate_cost, *recon_error)
+                            << TPL_DEP_COST_SCALE_LOG2;
+  }
   tpl_stats->recrf_dist = AOMMAX(tpl_stats->srcrf_dist, tpl_stats->recrf_dist);
+  tpl_stats->rec_rdcost = AOMMAX(tpl_stats->rec_rdcost, tpl_stats->src_rdcost);
 
   if (frame_idx && best_rf_idx != -1) {
     tpl_stats->mv.as_int = best_mv.as_int;
@@ -433,11 +463,11 @@
                     (1.0 - iiratio_nl));
 
       int64_t cur_dep_cost =
-          tpl_stats_ptr->recrf_dist - tpl_stats_ptr->srcrf_dist;
-      int64_t mc_dep_delta =
-          (tpl_stats_ptr->mc_dep_delta *
-           (tpl_stats_ptr->recrf_dist - tpl_stats_ptr->srcrf_dist)) /
-          tpl_stats_ptr->recrf_dist;
+          tpl_stats_ptr->rec_rdcost - tpl_stats_ptr->src_rdcost;
+      int64_t mc_dep_delta = (int64_t)(
+          tpl_stats_ptr->mc_dep_delta *
+          ((double)(tpl_stats_ptr->recrf_dist - tpl_stats_ptr->srcrf_dist) /
+           tpl_stats_ptr->recrf_dist));
 
 #if !USE_TPL_CLASSIC_MODEL
       int64_t mc_saved = tpl_stats_ptr->intra_cost - tpl_stats_ptr->inter_cost;
@@ -493,11 +523,15 @@
   int64_t inter_cost = src_stats->inter_cost / (mi_height * mi_width);
   int64_t srcrf_dist = src_stats->srcrf_dist / (mi_height * mi_width);
   int64_t recrf_dist = src_stats->recrf_dist / (mi_height * mi_width);
+  int64_t src_rdcost = src_stats->src_rdcost / (mi_height * mi_width);
+  int64_t rec_rdcost = src_stats->rec_rdcost / (mi_height * mi_width);
 
   intra_cost = AOMMAX(1, intra_cost);
   inter_cost = AOMMAX(1, inter_cost);
   srcrf_dist = AOMMAX(1, srcrf_dist);
   recrf_dist = AOMMAX(1, recrf_dist);
+  src_rdcost = AOMMAX(1, src_rdcost);
+  rec_rdcost = AOMMAX(1, rec_rdcost);
 
   for (int idy = 0; idy < mi_height; idy += step) {
     TplDepStats *tpl_ptr =
@@ -507,6 +541,8 @@
       tpl_ptr->inter_cost = inter_cost;
       tpl_ptr->srcrf_dist = srcrf_dist;
       tpl_ptr->recrf_dist = recrf_dist;
+      tpl_ptr->src_rdcost = src_rdcost;
+      tpl_ptr->rec_rdcost = rec_rdcost;
       tpl_ptr->quant_ratio = src_stats->quant_ratio;
       tpl_ptr->mv.as_int = src_stats->mv.as_int;
       tpl_ptr->ref_frame_index = src_stats->ref_frame_index;
@@ -531,7 +567,8 @@
   }
 }
 
-static AOM_INLINE void mc_flow_dispenser(AV1_COMP *cpi, int frame_idx) {
+static AOM_INLINE void mc_flow_dispenser(AV1_COMP *cpi, int frame_idx,
+                                         int pframe_qindex) {
   const GF_GROUP *gf_group = &cpi->gf_group;
   if (frame_idx == gf_group->size) return;
   TplDepFrame *tpl_frame = &cpi->tpl_frame[frame_idx];
@@ -618,7 +655,7 @@
 
   xd->block_ref_scale_factors[0] = &sf;
 
-  const int base_qindex = gf_group->q_val[frame_idx];
+  const int base_qindex = pframe_qindex;
   // Get rd multiplier set up.
   rdmult = (int)av1_compute_rd_mult(cpi, base_qindex);
   if (rdmult < 1) rdmult = 1;
@@ -630,6 +667,8 @@
   cm->base_qindex = base_qindex;
   av1_frame_init_quantizer(cpi);
 
+  int base_rdmult = av1_compute_rd_mult_based_on_qindex(cpi, pframe_qindex) / 6;
+
   for (mi_row = 0; mi_row < cm->mi_rows; mi_row += mi_height) {
     // Motion estimation row boundary
     x->mv_limits.row_min = -((mi_row * MI_SIZE) + (17 - 2 * AOM_INTERP_EXTEND));
@@ -649,7 +688,8 @@
       xd->mb_to_right_edge = ((cm->mi_cols - mi_width - mi_col) * MI_SIZE) * 8;
       mode_estimation(cpi, x, xd, &sf, frame_idx, src_diff, coeff, qcoeff,
                       dqcoeff, mi_row, mi_col, bsize, tx_size, ref_frame,
-                      src_frame, predictor, &recon_error, &sse, &tpl_stats);
+                      src_frame, predictor, base_rdmult, &recon_error, &sse,
+                      &tpl_stats);
 
       // Motion flow dependency dispenser.
       double quant_ratio = (double)recon_error / sse;
@@ -688,13 +728,13 @@
 static AOM_INLINE void init_gop_frames_for_tpl(
     AV1_COMP *cpi, const EncodeFrameParams *const init_frame_params,
     GF_GROUP *gf_group, int *tpl_group_frames,
-    const EncodeFrameInput *const frame_input) {
+    const EncodeFrameInput *const frame_input, int *pframe_qindex) {
   AV1_COMMON *cm = &cpi->common;
   const SequenceHeader *const seq_params = &cm->seq_params;
   int frame_idx = 0;
   RefCntBuffer *frame_bufs = cm->buffer_pool->frame_bufs;
-  int pframe_qindex = 0;
   int cur_frame_idx = gf_group->index;
+  *pframe_qindex = 0;
 
   RefBufferStack ref_buffer_stack = cpi->ref_buffer_stack;
   EncodeFrameParams frame_params = *init_frame_params;
@@ -749,7 +789,7 @@
         frame_update_type == KF_UPDATE ? KEY_FRAME : INTER_FRAME;
 
     if (frame_update_type == LF_UPDATE)
-      pframe_qindex = gf_group->q_val[gf_index];
+      *pframe_qindex = gf_group->q_val[gf_index];
 
     if (gf_index == cur_frame_idx) {
       tpl_frame->gf_picture = frame_input->source;
@@ -818,7 +858,7 @@
         frame_display_index + cpi->common.current_frame.display_order_hint;
 
     gf_group->update_type[gf_index] = LF_UPDATE;
-    gf_group->q_val[gf_index] = pframe_qindex;
+    gf_group->q_val[gf_index] = *pframe_qindex;
 
     av1_get_ref_frames(cpi, &ref_buffer_stack);
     int refresh_mask = av1_get_refresh_frame_flags(
@@ -875,8 +915,12 @@
     cm->current_frame.frame_type = INTER_FRAME;
   }
 
+  int pframe_qindex;
   init_gop_frames_for_tpl(cpi, frame_params, gf_group,
-                          &cpi->tpl_gf_group_frames, frame_input);
+                          &cpi->tpl_gf_group_frames, frame_input,
+                          &pframe_qindex);
+
+  cpi->rc.base_layer_qp = pframe_qindex;
 
   init_tpl_stats(cpi);
 
@@ -888,7 +932,7 @@
           gf_group->update_type[frame_idx] == OVERLAY_UPDATE)
         continue;
 
-      mc_flow_dispenser(cpi, frame_idx);
+      mc_flow_dispenser(cpi, frame_idx, pframe_qindex);
 
       aom_extend_frame_borders(cpi->tpl_frame[frame_idx].rec_picture,
                                av1_num_planes(cm));