Rework tpl model estimation

Accumulate the rate and distortion of all three planes (Y, U and V) when
estimating the TPL model, instead of using the luma plane only.

STATS_CHANGED

Change-Id: I936a323bba57bf47cc006b4f5905822c25b87807
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index af7b95c..55a044d 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -248,12 +248,41 @@
 
   int mb_y_offset = mi_row * MI_SIZE * xd->cur_buf->y_stride + mi_col * MI_SIZE;
   uint8_t *src_mb_buffer = xd->cur_buf->y_buffer + mb_y_offset;
-  const int src_stride = xd->cur_buf->y_stride;
+  int src_stride = xd->cur_buf->y_stride;
 
-  const int dst_mb_offset =
+  uint8_t *src_buffer_pool[MAX_MB_PLANE] = {
+    xd->cur_buf->y_buffer,
+    xd->cur_buf->u_buffer,
+    xd->cur_buf->v_buffer,
+  };
+  int src_stride_pool[MAX_MB_PLANE] = {
+    xd->cur_buf->y_stride,
+    xd->cur_buf->uv_stride,
+    xd->cur_buf->uv_stride,
+  };
+
+  int dst_mb_offset =
       mi_row * MI_SIZE * tpl_frame->rec_picture->y_stride + mi_col * MI_SIZE;
   uint8_t *dst_buffer = tpl_frame->rec_picture->y_buffer + dst_mb_offset;
-  const int dst_buffer_stride = tpl_frame->rec_picture->y_stride;
+  int dst_buffer_stride = tpl_frame->rec_picture->y_stride;
+
+  uint8_t *rec_buffer_pool[MAX_MB_PLANE] = {
+    tpl_frame->rec_picture->y_buffer,
+    tpl_frame->rec_picture->u_buffer,
+    tpl_frame->rec_picture->v_buffer,
+  };
+  // Indexed by plane, like the buffers above; U and V share uv_stride.
+  int rec_stride_pool[MAX_MB_PLANE] = {
+    tpl_frame->rec_picture->y_stride,
+    tpl_frame->rec_picture->uv_stride,
+    tpl_frame->rec_picture->uv_stride,
+  };
+
+  for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
+    struct macroblockd_plane *pd = &xd->plane[plane];
+    pd->subsampling_x = xd->cur_buf->subsampling_x;
+    pd->subsampling_y = xd->cur_buf->subsampling_y;
+  }
 
   // Number of pixels in a tpl block
   const int tpl_block_pels = tpl_data->tpl_bsize_1d * tpl_data->tpl_bsize_1d;
@@ -463,11 +492,56 @@
   }
 
   if (best_inter_cost < INT64_MAX) {
-    uint16_t eob;
-    get_quantize_error(x, 0, best_coeff, qcoeff, dqcoeff, tx_size, &eob,
-                       &recon_error, &sse);
+    const YV12_BUFFER_CONFIG *ref_frame_ptr =
+        tpl_data->src_ref_frame[best_rf_idx];
 
-    const int rate_cost = rate_estimator(qcoeff, eob, tx_size);
+    uint8_t *ref_buffer_pool[MAX_MB_PLANE] = {
+      ref_frame_ptr->y_buffer,
+      ref_frame_ptr->u_buffer,
+      ref_frame_ptr->v_buffer,
+    };
+
+    int rate_cost = 1;
+
+    for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
+      struct macroblockd_plane *pd = &xd->plane[plane];
+      BLOCK_SIZE bsize_plane =
+          ss_size_lookup[txsize_to_bsize[tx_size]][pd->subsampling_x]
+                        [pd->subsampling_y];
+      InterPredParams inter_pred_params;
+      struct buf_2d ref_buf = {
+        NULL, ref_buffer_pool[plane],
+        plane ? ref_frame_ptr->uv_width : ref_frame_ptr->y_width,
+        plane ? ref_frame_ptr->uv_height : ref_frame_ptr->y_height,
+        plane ? ref_frame_ptr->uv_stride : ref_frame_ptr->y_stride
+      };
+      av1_init_inter_params(
+          &inter_pred_params, bw >> pd->subsampling_x, bh >> pd->subsampling_y,
+          (mi_row * MI_SIZE) >> pd->subsampling_y,
+          (mi_col * MI_SIZE) >> pd->subsampling_x, pd->subsampling_x,
+          pd->subsampling_y, xd->bd, is_cur_buf_hbd(xd), 0, &tpl_data->sf,
+          &ref_buf, kernel);
+      inter_pred_params.conv_params = get_conv_params(0, plane, xd->bd);
+      // Write the predictor with the same stride it is read back with below.
+      av1_enc_build_one_inter_predictor(predictor, bw >> pd->subsampling_x,
+                                        &best_mv.as_mv, &inter_pred_params);
+      // Rows scale by subsampling_y, columns by subsampling_x.
+      src_stride = src_stride_pool[plane];
+      int src_mb_offset =
+          ((mi_row * MI_SIZE) >> pd->subsampling_y) * src_stride +
+          ((mi_col * MI_SIZE) >> pd->subsampling_x);
+      int this_rate = 1;
+      int64_t this_recon_error = 1;
+      txfm_quant_rdcost(x, src_diff, bw >> pd->subsampling_x,
+                        src_buffer_pool[plane] + src_mb_offset, src_stride,
+                        predictor, bw >> pd->subsampling_x, coeff, qcoeff,
+                        dqcoeff, bw >> pd->subsampling_x,
+                        bh >> pd->subsampling_y,
+                        max_txsize_rect_lookup[bsize_plane], &this_rate,
+                        &this_recon_error, &sse);
+      rate_cost += this_rate;
+      recon_error += this_recon_error;
+    }
     tpl_stats->srcrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;
   }
 
@@ -479,32 +553,71 @@
   tpl_stats->srcrf_dist = recon_error << (TPL_DEP_COST_SCALE_LOG2);
 
   // Final encode
-  if (!is_inter_mode(best_mode)) {
-    av1_predict_intra_block(cm, xd, block_size_wide[bsize],
-                            block_size_high[bsize], tx_size, best_mode, 0, 0,
-                            FILTER_INTRA_MODES, dst_buffer, dst_buffer_stride,
-                            dst_buffer, dst_buffer_stride, 0, 0, 0);
-  } else {
-    const YV12_BUFFER_CONFIG *ref_frame_ptr = tpl_data->ref_frame[best_rf_idx];
+  int rate_cost = 0;
+  recon_error = 1;
 
-    InterPredParams inter_pred_params;
-    struct buf_2d ref_buf = { NULL, ref_frame_ptr->y_buffer,
-                              ref_frame_ptr->y_width, ref_frame_ptr->y_height,
-                              ref_frame_ptr->y_stride };
-    av1_init_inter_params(&inter_pred_params, bw, bh, mi_row * MI_SIZE,
-                          mi_col * MI_SIZE, 0, 0, xd->bd, is_cur_buf_hbd(xd), 0,
-                          &tpl_data->sf, &ref_buf, kernel);
-    inter_pred_params.conv_params = get_conv_params(0, 0, xd->bd);
+  for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
+    struct macroblockd_plane *pd = &xd->plane[plane];
 
-    av1_enc_build_one_inter_predictor(dst_buffer, dst_buffer_stride,
-                                      &best_mv.as_mv, &inter_pred_params);
+    BLOCK_SIZE bsize_plane =
+        ss_size_lookup[txsize_to_bsize[tx_size]][pd->subsampling_x]
+                      [pd->subsampling_y];
+    // Rows scale by subsampling_y, columns by subsampling_x.
+    dst_buffer_stride = rec_stride_pool[plane];
+    dst_mb_offset =
+        ((mi_row * MI_SIZE) >> pd->subsampling_y) * dst_buffer_stride +
+        ((mi_col * MI_SIZE) >> pd->subsampling_x);
+    dst_buffer = rec_buffer_pool[plane] + dst_mb_offset;
+    if (!is_inter_mode(best_mode)) {
+      av1_predict_intra_block(
+          cm, xd, block_size_wide[bsize_plane], block_size_high[bsize_plane],
+          max_txsize_rect_lookup[bsize_plane], best_mode, 0, 0,
+          FILTER_INTRA_MODES, dst_buffer, dst_buffer_stride, dst_buffer,
+          dst_buffer_stride, 0, 0, plane);
+    } else {
+      const YV12_BUFFER_CONFIG *ref_frame_ptr =
+          tpl_data->ref_frame[best_rf_idx];
+      uint8_t *ref_buffer_pool[MAX_MB_PLANE] = {
+        ref_frame_ptr->y_buffer,
+        ref_frame_ptr->u_buffer,
+        ref_frame_ptr->v_buffer,
+      };
+      InterPredParams inter_pred_params;
+      struct buf_2d ref_buf = {
+        NULL, ref_buffer_pool[plane],
+        plane ? ref_frame_ptr->uv_width : ref_frame_ptr->y_width,
+        plane ? ref_frame_ptr->uv_height : ref_frame_ptr->y_height,
+        plane ? ref_frame_ptr->uv_stride : ref_frame_ptr->y_stride
+      };
+      av1_init_inter_params(
+          &inter_pred_params, bw >> pd->subsampling_x, bh >> pd->subsampling_y,
+          (mi_row * MI_SIZE) >> pd->subsampling_y,
+          (mi_col * MI_SIZE) >> pd->subsampling_x, pd->subsampling_x,
+          pd->subsampling_y, xd->bd, is_cur_buf_hbd(xd), 0, &tpl_data->sf,
+          &ref_buf, kernel);
+      inter_pred_params.conv_params = get_conv_params(0, plane, xd->bd);
+
+      av1_enc_build_one_inter_predictor(dst_buffer, dst_buffer_stride,
+                                        &best_mv.as_mv, &inter_pred_params);
+    }
+
+    src_stride = src_stride_pool[plane];
+    int src_mb_offset =
+        ((mi_row * MI_SIZE) >> pd->subsampling_y) * src_stride +
+        ((mi_col * MI_SIZE) >> pd->subsampling_x);
+    int this_rate = 1;
+    int64_t this_recon_error = 1;
+    txfm_quant_rdcost(x, src_diff, bw >> pd->subsampling_x,
+                      src_buffer_pool[plane] + src_mb_offset, src_stride,
+                      dst_buffer, dst_buffer_stride, coeff, qcoeff, dqcoeff,
+                      bw >> pd->subsampling_x, bh >> pd->subsampling_y,
+                      max_txsize_rect_lookup[bsize_plane], &this_rate,
+                      &this_recon_error, &sse);
+
+    recon_error += this_recon_error;
+    rate_cost += this_rate;
   }
 
-  int rate_cost;
-  txfm_quant_rdcost(x, src_diff, bw, src_mb_buffer, src_stride, dst_buffer,
-                    dst_buffer_stride, coeff, qcoeff, dqcoeff, bw, bh, tx_size,
-                    &rate_cost, &recon_error, &sse);
-
   tpl_stats->recrf_dist = recon_error << (TPL_DEP_COST_SCALE_LOG2);
   tpl_stats->recrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;
   if (!is_inter_mode(best_mode)) {