Factor out common operations in tpl model

With a bug fix in the common functions.

STATS_CHANGED

Change-Id: I17fb12a82c8255166fb101efd307999dc44a89b7
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 55a044d..821cbea 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -223,6 +223,93 @@
   return 0;
 }
 
+static void get_rate_distortion(
+    int *rate_cost, int64_t *recon_error, int16_t *src_diff, tran_low_t *coeff,
+    tran_low_t *qcoeff, tran_low_t *dqcoeff, AV1_COMMON *cm, MACROBLOCK *x,
+    const YV12_BUFFER_CONFIG *ref_frame_ptr, uint8_t *rec_buffer_pool[3],
+    int rec_stride_pool[3], TX_SIZE tx_size, PREDICTION_MODE best_mode,
+    int_mv best_mv, int mi_row, int mi_col) {
+  *rate_cost = 0;
+  *recon_error = 1;
+
+  MACROBLOCKD *xd = &x->e_mbd;
+
+  uint8_t *src_buffer_pool[MAX_MB_PLANE] = {
+    xd->cur_buf->y_buffer,
+    xd->cur_buf->u_buffer,
+    xd->cur_buf->v_buffer,
+  };
+  int src_stride_pool[MAX_MB_PLANE] = {
+    xd->cur_buf->y_stride,
+    xd->cur_buf->uv_stride,
+    xd->cur_buf->uv_stride,
+  };
+
+  const int_interpfilters kernel =
+      av1_broadcast_interp_filter(EIGHTTAP_REGULAR);
+
+  for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
+    struct macroblockd_plane *pd = &xd->plane[plane];
+
+    BLOCK_SIZE bsize_plane =
+        ss_size_lookup[txsize_to_bsize[tx_size]][pd->subsampling_x]
+                      [pd->subsampling_y];
+
+    int dst_buffer_stride = rec_stride_pool[plane];
+    int dst_mb_offset =
+        (mi_row * MI_SIZE * dst_buffer_stride + mi_col * MI_SIZE) >>
+        pd->subsampling_x;
+    uint8_t *dst_buffer = rec_buffer_pool[plane] + dst_mb_offset;
+    if (!is_inter_mode(best_mode)) {
+      av1_predict_intra_block(
+          cm, xd, block_size_wide[bsize_plane], block_size_high[bsize_plane],
+          max_txsize_rect_lookup[bsize_plane], best_mode, 0, 0,
+          FILTER_INTRA_MODES, dst_buffer, dst_buffer_stride, dst_buffer,
+          dst_buffer_stride, 0, 0, plane);
+    } else {
+      uint8_t *ref_buffer_pool[MAX_MB_PLANE] = {
+        ref_frame_ptr->y_buffer,
+        ref_frame_ptr->u_buffer,
+        ref_frame_ptr->v_buffer,
+      };
+      InterPredParams inter_pred_params;
+      struct buf_2d ref_buf = {
+        NULL, ref_buffer_pool[plane],
+        plane ? ref_frame_ptr->uv_width : ref_frame_ptr->y_width,
+        plane ? ref_frame_ptr->uv_height : ref_frame_ptr->y_height,
+        plane ? ref_frame_ptr->uv_stride : ref_frame_ptr->y_stride
+      };
+      av1_init_inter_params(
+          &inter_pred_params, block_size_wide[bsize_plane],
+          block_size_high[bsize_plane], (mi_row * MI_SIZE) >> pd->subsampling_y,
+          (mi_col * MI_SIZE) >> pd->subsampling_x, pd->subsampling_x,
+          pd->subsampling_y, xd->bd, is_cur_buf_hbd(xd), 0,
+          xd->block_ref_scale_factors[0], &ref_buf, kernel);
+      inter_pred_params.conv_params = get_conv_params(0, plane, xd->bd);
+
+      av1_enc_build_one_inter_predictor(dst_buffer, dst_buffer_stride,
+                                        &best_mv.as_mv, &inter_pred_params);
+    }
+
+    int src_stride = src_stride_pool[plane];
+    int src_mb_offset =
+        (mi_row * MI_SIZE * src_stride + mi_col * MI_SIZE) >> pd->subsampling_x;
+
+    int this_rate = 1;
+    int64_t this_recon_error = 1;
+    int64_t sse;
+    txfm_quant_rdcost(
+        x, src_diff, block_size_wide[bsize_plane],
+        src_buffer_pool[plane] + src_mb_offset, src_stride, dst_buffer,
+        dst_buffer_stride, coeff, qcoeff, dqcoeff, block_size_wide[bsize_plane],
+        block_size_high[bsize_plane], max_txsize_rect_lookup[bsize_plane],
+        &this_rate, &this_recon_error, &sse);
+
+    *recon_error += this_recon_error;
+    *rate_cost += this_rate;
+  }
+}
+
 static AOM_INLINE void mode_estimation(AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
                                        int mi_col, BLOCK_SIZE bsize,
                                        TX_SIZE tx_size,
@@ -250,17 +337,6 @@
   uint8_t *src_mb_buffer = xd->cur_buf->y_buffer + mb_y_offset;
   int src_stride = xd->cur_buf->y_stride;
 
-  uint8_t *src_buffer_pool[MAX_MB_PLANE] = {
-    xd->cur_buf->y_buffer,
-    xd->cur_buf->u_buffer,
-    xd->cur_buf->v_buffer,
-  };
-  int src_stride_pool[MAX_MB_PLANE] = {
-    xd->cur_buf->y_stride,
-    xd->cur_buf->uv_stride,
-    xd->cur_buf->uv_stride,
-  };
-
   int dst_mb_offset =
       mi_row * MI_SIZE * tpl_frame->rec_picture->y_stride + mi_col * MI_SIZE;
   uint8_t *dst_buffer = tpl_frame->rec_picture->y_buffer + dst_mb_offset;
@@ -296,7 +372,7 @@
       aom_memalign(32, tpl_block_pels * sizeof(tran_low_t));
   uint8_t *predictor =
       is_cur_buf_hbd(xd) ? CONVERT_TO_BYTEPTR(predictor8) : predictor8;
-  int64_t recon_error = 1, sse = 1;
+  int64_t recon_error = 1;
 
   memset(tpl_stats, 0, sizeof(*tpl_stats));
 
@@ -494,54 +570,11 @@
   if (best_inter_cost < INT64_MAX) {
     const YV12_BUFFER_CONFIG *ref_frame_ptr =
         tpl_data->src_ref_frame[best_rf_idx];
-
-    uint8_t *ref_buffer_pool[MAX_MB_PLANE] = {
-      ref_frame_ptr->y_buffer,
-      ref_frame_ptr->u_buffer,
-      ref_frame_ptr->v_buffer,
-    };
-
     int rate_cost = 1;
-
-    for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
-      struct macroblockd_plane *pd = &xd->plane[plane];
-      BLOCK_SIZE bsize_plane =
-          ss_size_lookup[txsize_to_bsize[tx_size]][pd->subsampling_x]
-                        [pd->subsampling_y];
-      InterPredParams inter_pred_params;
-      struct buf_2d ref_buf = {
-        NULL, ref_buffer_pool[plane],
-        plane ? ref_frame_ptr->uv_width : ref_frame_ptr->y_width,
-        plane ? ref_frame_ptr->uv_height : ref_frame_ptr->y_height,
-        plane ? ref_frame_ptr->uv_stride : ref_frame_ptr->y_stride
-      };
-      av1_init_inter_params(
-          &inter_pred_params, bw >> pd->subsampling_x, bh >> pd->subsampling_y,
-          (mi_row * MI_SIZE) >> pd->subsampling_y,
-          (mi_col * MI_SIZE) >> pd->subsampling_x, pd->subsampling_x,
-          pd->subsampling_y, xd->bd, is_cur_buf_hbd(xd), 0, &tpl_data->sf,
-          &ref_buf, kernel);
-      inter_pred_params.conv_params = get_conv_params(0, plane, xd->bd);
-
-      av1_enc_build_one_inter_predictor(predictor, bw >> pd->subsampling_y,
-                                        &best_mv.as_mv, &inter_pred_params);
-
-      src_stride = src_stride_pool[plane];
-      int src_mb_offset = (mi_row * MI_SIZE * src_stride + mi_col * MI_SIZE) >>
-                          pd->subsampling_x;
-
-      int this_rate = 1;
-      int64_t this_recon_error = 1;
-      txfm_quant_rdcost(x, src_diff, bw >> pd->subsampling_x,
-                        src_buffer_pool[plane] + src_mb_offset, src_stride,
-                        predictor, bw >> pd->subsampling_x, coeff, qcoeff,
-                        dqcoeff, bw >> pd->subsampling_x,
-                        bh >> pd->subsampling_x,
-                        max_txsize_rect_lookup[bsize_plane], &this_rate,
-                        &this_recon_error, &sse);
-      rate_cost += this_recon_error;
-      recon_error += this_recon_error;
-    }
+    get_rate_distortion(&rate_cost, &recon_error, src_diff, coeff, qcoeff,
+                        dqcoeff, cm, x, ref_frame_ptr, rec_buffer_pool,
+                        rec_stride_pool, tx_size, best_mode, best_mv, mi_row,
+                        mi_col);
     tpl_stats->srcrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;
   }
 
@@ -554,69 +587,11 @@
 
   // Final encode
   int rate_cost = 0;
-  recon_error = 1;
-
-  for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
-    struct macroblockd_plane *pd = &xd->plane[plane];
-
-    BLOCK_SIZE bsize_plane =
-        ss_size_lookup[txsize_to_bsize[tx_size]][pd->subsampling_x]
-                      [pd->subsampling_y];
-
-    dst_buffer_stride = rec_stride_pool[plane];
-    dst_mb_offset = (mi_row * MI_SIZE * dst_buffer_stride + mi_col * MI_SIZE) >>
-                    pd->subsampling_x;
-    dst_buffer = rec_buffer_pool[plane] + dst_mb_offset;
-    if (!is_inter_mode(best_mode)) {
-      av1_predict_intra_block(
-          cm, xd, block_size_wide[bsize_plane], block_size_high[bsize_plane],
-          max_txsize_rect_lookup[bsize_plane], best_mode, 0, 0,
-          FILTER_INTRA_MODES, dst_buffer, dst_buffer_stride, dst_buffer,
-          dst_buffer_stride, 0, 0, plane);
-    } else {
-      const YV12_BUFFER_CONFIG *ref_frame_ptr =
-          tpl_data->ref_frame[best_rf_idx];
-
-      uint8_t *ref_buffer_pool[MAX_MB_PLANE] = {
-        ref_frame_ptr->y_buffer,
-        ref_frame_ptr->u_buffer,
-        ref_frame_ptr->v_buffer,
-      };
-      InterPredParams inter_pred_params;
-      struct buf_2d ref_buf = {
-        NULL, ref_buffer_pool[plane],
-        plane ? ref_frame_ptr->uv_width : ref_frame_ptr->y_width,
-        plane ? ref_frame_ptr->uv_height : ref_frame_ptr->y_height,
-        plane ? ref_frame_ptr->uv_stride : ref_frame_ptr->y_stride
-      };
-      av1_init_inter_params(
-          &inter_pred_params, bw >> pd->subsampling_x, bh >> pd->subsampling_y,
-          (mi_row * MI_SIZE) >> pd->subsampling_y,
-          (mi_col * MI_SIZE) >> pd->subsampling_x, pd->subsampling_x,
-          pd->subsampling_y, xd->bd, is_cur_buf_hbd(xd), 0, &tpl_data->sf,
-          &ref_buf, kernel);
-      inter_pred_params.conv_params = get_conv_params(0, plane, xd->bd);
-
-      av1_enc_build_one_inter_predictor(dst_buffer, dst_buffer_stride,
-                                        &best_mv.as_mv, &inter_pred_params);
-    }
-
-    src_stride = src_stride_pool[plane];
-    int src_mb_offset =
-        (mi_row * MI_SIZE * src_stride + mi_col * MI_SIZE) >> pd->subsampling_x;
-
-    int this_rate = 1;
-    int64_t this_recon_error = 1;
-    txfm_quant_rdcost(x, src_diff, bw >> pd->subsampling_x,
-                      src_buffer_pool[plane] + src_mb_offset, src_stride,
-                      dst_buffer, dst_buffer_stride, coeff, qcoeff, dqcoeff,
-                      bw >> pd->subsampling_x, bh >> pd->subsampling_x,
-                      max_txsize_rect_lookup[bsize_plane], &this_rate,
-                      &this_recon_error, &sse);
-
-    recon_error += this_recon_error;
-    rate_cost += this_rate;
-  }
+  get_rate_distortion(
+      &rate_cost, &recon_error, src_diff, coeff, qcoeff, dqcoeff, cm, x,
+      best_rf_idx >= 0 ? tpl_data->ref_frame[best_rf_idx] : NULL,
+      rec_buffer_pool, rec_stride_pool, tx_size, best_mode, best_mv, mi_row,
+      mi_col);
 
   tpl_stats->recrf_dist = recon_error << (TPL_DEP_COST_SCALE_LOG2);
   tpl_stats->recrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;
@@ -947,6 +922,7 @@
   xd->mi = &mbmi_ptr;
 
   xd->block_ref_scale_factors[0] = &tpl_data->sf;
+  xd->block_ref_scale_factors[1] = &tpl_data->sf;
 
   const int base_qindex = pframe_qindex;
   // Get rd multiplier set up.