Refactor transform coding path in tpl

Consolidate the transform coding and rate-distortion estimation
operations for tpl model building.

Change-Id: I0bb0dfe912f11bc024ac22bcb633581e454dfa93
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index e528842..fb4fab5 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -35,10 +35,10 @@
 #include "av1/encoder/rd.h"
 #include "av1/encoder/rdopt.h"
 
-static void subtract_block(const MACROBLOCKD *xd, int rows, int cols,
-                           int16_t *diff, ptrdiff_t diff_stride,
-                           const uint8_t *src8, ptrdiff_t src_stride,
-                           const uint8_t *pred8, ptrdiff_t pred_stride) {
+void av1_subtract_block(const MACROBLOCKD *xd, int rows, int cols,
+                        int16_t *diff, ptrdiff_t diff_stride,
+                        const uint8_t *src8, ptrdiff_t src_stride,
+                        const uint8_t *pred8, ptrdiff_t pred_stride) {
   assert(rows >= 4 && cols >= 4);
 #if CONFIG_AV1_HIGHBITDEPTH
   if (is_cur_buf_hbd(xd)) {
@@ -68,8 +68,8 @@
       &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
   int16_t *src_diff =
       &p->src_diff[(blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]];
-  subtract_block(xd, tx1d_height, tx1d_width, src_diff, diff_stride, src,
-                 src_stride, dst, dst_stride);
+  av1_subtract_block(xd, tx1d_height, tx1d_width, src_diff, diff_stride, src,
+                     src_stride, dst, dst_stride);
 }
 
 void av1_subtract_plane(MACROBLOCK *x, BLOCK_SIZE bsize, int plane) {
@@ -82,8 +82,8 @@
   const int bh = block_size_high[plane_bsize];
   const MACROBLOCKD *xd = &x->e_mbd;
 
-  subtract_block(xd, bh, bw, p->src_diff, bw, p->src.buf, p->src.stride,
-                 pd->dst.buf, pd->dst.stride);
+  av1_subtract_block(xd, bh, bw, p->src_diff, bw, p->src.buf, p->src.stride,
+                     pd->dst.buf, pd->dst.stride);
 }
 
 int av1_optimize_b(const struct AV1_COMP *cpi, MACROBLOCK *mb, int plane,
diff --git a/av1/encoder/encodemb.h b/av1/encoder/encodemb.h
index d4394cf..efb2314 100644
--- a/av1/encoder/encodemb.h
+++ b/av1/encoder/encodemb.h
@@ -68,6 +68,11 @@
                    int block, TX_SIZE tx_size, TX_TYPE tx_type,
                    const TXB_CTX *const txb_ctx, int fast_mode, int *rate_cost);
 
+void av1_subtract_block(const MACROBLOCKD *xd, int rows, int cols,
+                        int16_t *diff, ptrdiff_t diff_stride,
+                        const uint8_t *src8, ptrdiff_t src_stride,
+                        const uint8_t *pred8, ptrdiff_t pred_stride);
+
 void av1_subtract_txb(MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize,
                       int blk_col, int blk_row, TX_SIZE tx_size);
 
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index aaef6de..abfa949 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -81,6 +81,28 @@
   return (rate_cost << AV1_PROB_COST_SHIFT);
 }
 
+static void txfm_quant_rdcost(MACROBLOCK *x, int16_t *src_diff, int diff_stride,
+                              uint8_t *src, int src_stride, uint8_t *dst,
+                              int dst_stride, tran_low_t *coeff,
+                              tran_low_t *qcoeff, tran_low_t *dqcoeff, int bw,
+                              int bh, TX_SIZE tx_size, int *rate_cost,
+                              int64_t *recon_error, int64_t *sse) {
+  const MACROBLOCKD *xd = &x->e_mbd;
+  uint16_t eob;
+  av1_subtract_block(xd, bh, bw, src_diff, diff_stride, src, src_stride, dst,
+                     dst_stride);
+  wht_fwd_txfm(src_diff, diff_stride, coeff, tx_size, xd->bd,
+               is_cur_buf_hbd(xd));
+
+  get_quantize_error(x, 0, coeff, qcoeff, dqcoeff, tx_size, &eob, recon_error,
+                     sse);
+
+  *rate_cost = rate_estimator(qcoeff, eob, tx_size);
+
+  av1_inverse_transform_block(xd, dqcoeff, 0, DCT_DCT, tx_size, dst, dst_stride,
+                              eob, 0);
+}
+
 static uint32_t motion_estimation(AV1_COMP *cpi, MACROBLOCK *x,
                                   uint8_t *cur_frame_buf,
                                   uint8_t *ref_frame_buf, int stride,
@@ -218,17 +240,8 @@
                             FILTER_INTRA_MODES, dst_buffer, dst_buffer_stride,
                             dst, dst_stride, 0, 0, 0);
 
-#if CONFIG_AV1_HIGHBITDEPTH
-    if (is_cur_buf_hbd(xd)) {
-      aom_highbd_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst,
-                                dst_stride, xd->bd);
-    } else {
-      aom_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst,
-                         dst_stride);
-    }
-#else
-    aom_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst, dst_stride);
-#endif
+    av1_subtract_block(xd, bh, bw, src_diff, bw, src, src_stride, dst,
+                       dst_stride);
     wht_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd, is_cur_buf_hbd(xd));
     intra_cost = aom_satd(coeff, pix_num);
 
@@ -253,6 +266,8 @@
     if (ref_frame[rf_idx] == NULL) continue;
     if (src_ref_frame[rf_idx] == NULL) continue;
 
+    int rate_cost;
+    int64_t distortion, tsse;
     const YV12_BUFFER_CONFIG *ref_frame_ptr = src_ref_frame[rf_idx];
     int ref_mb_offset =
         mi_row * MI_SIZE * ref_frame_ptr->y_stride + mi_col * MI_SIZE;
@@ -274,32 +289,19 @@
     av1_build_inter_predictor(predictor, bw, &x->best_mv.as_mv,
                               &inter_pred_params);
 
-#if CONFIG_AV1_HIGHBITDEPTH
-    if (is_cur_buf_hbd(xd)) {
-      aom_highbd_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
-                                predictor, bw, xd->bd);
-    } else {
-      aom_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
-                         predictor, bw);
-    }
-#else
-    aom_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
-                       predictor, bw);
-#endif
+    txfm_quant_rdcost(x, src_diff, bw, src_mb_buffer, src_stride, predictor, bw,
+                      coeff, qcoeff, dqcoeff, bw, bh, tx_size, &rate_cost,
+                      &distortion, &tsse);
 
-    wht_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd, is_cur_buf_hbd(xd));
     inter_cost = aom_satd(coeff, pix_num);
 
     if (inter_cost < best_inter_cost) {
-      uint16_t eob;
       best_rf_idx = rf_idx;
       best_inter_cost = inter_cost;
       best_mv.as_int = x->best_mv.as_int;
-      get_quantize_error(x, 0, coeff, qcoeff, dqcoeff, tx_size, &eob,
-                         recon_error, sse);
-      int rate_cost = rate_estimator(qcoeff, eob, tx_size);
       tpl_stats->srcrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;
-
+      *recon_error = distortion;
+      *sse = tsse;
       if (best_inter_cost < best_intra_cost) best_mode = NEWMV;
     }
   }
@@ -335,29 +337,10 @@
                             dst_buffer, dst_buffer_stride, 0, 0, 0);
   }
 
-#if CONFIG_AV1_HIGHBITDEPTH
-  if (is_cur_buf_hbd(xd)) {
-    aom_highbd_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
-                              dst_buffer, dst_buffer_stride, xd->bd);
-  } else {
-    aom_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
-                       dst_buffer, dst_buffer_stride);
-  }
-#else
-  aom_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
-                     dst_buffer, dst_buffer_stride);
-#endif
-  wht_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd, is_cur_buf_hbd(xd));
-
-  uint16_t eob;
-
-  get_quantize_error(x, 0, coeff, qcoeff, dqcoeff, tx_size, &eob, recon_error,
-                     sse);
-
-  int rate_cost = rate_estimator(qcoeff, eob, tx_size);
-
-  av1_inverse_transform_block(xd, dqcoeff, 0, DCT_DCT, tx_size, dst_buffer,
-                              dst_buffer_stride, eob, 0);
+  int rate_cost;
+  txfm_quant_rdcost(x, src_diff, bw, src_mb_buffer, src_stride, dst_buffer,
+                    dst_buffer_stride, coeff, qcoeff, dqcoeff, bw, bh, tx_size,
+                    &rate_cost, recon_error, sse);
 
   tpl_stats->recrf_dist = *recon_error << (TPL_DEP_COST_SCALE_LOG2);
   tpl_stats->recrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;