Rework forward transform and quantization in tpl

Use 2D-DCT to replace the WHT transform in tpl model pipeline.

STATS_CHANGED

Change-Id: I536c45e2aeb3ae9cce5e2d14c111cbcb2f562bbe
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index bb11f1e..004211e 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -25,6 +25,7 @@
 
 #include "av1/encoder/encoder.h"
 #include "av1/encoder/encode_strategy.h"
+#include "av1/encoder/hybrid_fwd_txfm.h"
 #include "av1/encoder/rdopt.h"
 #include "av1/encoder/reconinter_enc.h"
 #include "av1/encoder/tpl_model.h"
@@ -52,32 +53,16 @@
 
 static AOM_INLINE void wht_fwd_txfm(int16_t *src_diff, int bw,
                                     tran_low_t *coeff, TX_SIZE tx_size,
-                                    int is_hbd) {
-#if CONFIG_AV1_HIGHBITDEPTH
-  if (is_hbd) {
-    switch (tx_size) {
-      case TX_8X8: aom_highbd_hadamard_8x8(src_diff, bw, coeff); break;
-      case TX_16X16: aom_highbd_hadamard_16x16(src_diff, bw, coeff); break;
-      case TX_32X32: aom_highbd_hadamard_32x32(src_diff, bw, coeff); break;
-      default: assert(0);
-    }
-  } else {
-    switch (tx_size) {
-      case TX_8X8: aom_hadamard_8x8(src_diff, bw, coeff); break;
-      case TX_16X16: aom_hadamard_16x16(src_diff, bw, coeff); break;
-      case TX_32X32: aom_hadamard_32x32(src_diff, bw, coeff); break;
-      default: assert(0);
-    }
-  }
-#else
-  (void)is_hbd;
-  switch (tx_size) {
-    case TX_8X8: aom_hadamard_8x8(src_diff, bw, coeff); break;
-    case TX_16X16: aom_hadamard_16x16(src_diff, bw, coeff); break;
-    case TX_32X32: aom_hadamard_32x32(src_diff, bw, coeff); break;
-    default: assert(0);
-  }
-#endif
+                                    int bit_depth, int is_hbd) {
+  TxfmParam txfm_param;
+  txfm_param.tx_type = DCT_DCT;
+  txfm_param.tx_size = tx_size;
+  txfm_param.lossless = 0;
+  txfm_param.tx_set_type = EXT_TX_SET_ALL16;
+
+  txfm_param.bd = bit_depth;
+  txfm_param.is_hbd = is_hbd;
+  av1_fwd_txfm(src_diff, coeff, bw, &txfm_param);
 }
 
 static uint32_t motion_estimation(AV1_COMP *cpi, MACROBLOCK *x,
@@ -138,22 +123,28 @@
 static AOM_INLINE void mode_estimation(
     AV1_COMP *cpi, MACROBLOCK *x, MACROBLOCKD *xd, struct scale_factors *sf,
     int frame_idx, int16_t *src_diff, tran_low_t *coeff, tran_low_t *qcoeff,
-    tran_low_t *dqcoeff, int use_satd, int mi_row, int mi_col, BLOCK_SIZE bsize,
+    tran_low_t *dqcoeff, int mi_row, int mi_col, BLOCK_SIZE bsize,
     TX_SIZE tx_size, const YV12_BUFFER_CONFIG *ref_frame[], uint8_t *predictor,
     int64_t *recon_error, int64_t *sse, TplDepStats *tpl_stats) {
   AV1_COMMON *cm = &cpi->common;
   const GF_GROUP *gf_group = &cpi->gf_group;
 
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[frame_idx];
+
   const int bw = 4 << mi_size_wide_log2[bsize];
   const int bh = 4 << mi_size_high_log2[bsize];
   const int pix_num = bw * bh;
   const int_interpfilters kernel =
       av1_broadcast_interp_filter(EIGHTTAP_REGULAR);
 
+  (void)predictor;
+
   int64_t best_intra_cost = INT64_MAX;
   int64_t intra_cost;
   PREDICTION_MODE mode;
   int mb_y_offset = mi_row * MI_SIZE * xd->cur_buf->y_stride + mi_col * MI_SIZE;
+  uint8_t *dst_buffer = tpl_frame->rec_picture->y_buffer + mb_y_offset;
+  const int dst_buffer_stride = tpl_frame->rec_picture->y_stride;
 
   memset(tpl_stats, 0, sizeof(*tpl_stats));
 
@@ -165,9 +156,7 @@
   const int q_cur = gf_group->q_val[frame_idx];
   const int16_t qstep_cur =
       ROUND_POWER_OF_TWO(av1_ac_quant_QTX(q_cur, 0, xd->bd), xd->bd - 8);
-  const int qstep_cur_noise =
-      use_satd ? ((int)qstep_cur * pix_num + 16) / (4 * 8)
-               : ((int)qstep_cur * (int)qstep_cur * pix_num + 384) / (12 * 64);
+  const int qstep_cur_noise = ((int)qstep_cur * pix_num + 16) / (4 * 8);
 
   // Intra prediction search
   xd->mi[0]->ref_frame[0] = INTRA_FRAME;
@@ -178,45 +167,26 @@
     src = xd->cur_buf->y_buffer + mb_y_offset;
     src_stride = xd->cur_buf->y_stride;
 
-    dst = predictor;
-    dst_stride = bw;
+    dst = dst_buffer;
+    dst_stride = dst_buffer_stride;
 
     av1_predict_intra_block(
         cm, xd, block_size_wide[bsize], block_size_high[bsize], tx_size, mode,
         0, 0, FILTER_INTRA_MODES, src, src_stride, dst, dst_stride, 0, 0, 0);
 
-    if (use_satd) {
 #if CONFIG_AV1_HIGHBITDEPTH
-      if (is_cur_buf_hbd(xd)) {
-        aom_highbd_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst,
-                                  dst_stride, xd->bd);
-      } else {
-        aom_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst,
-                           dst_stride);
-      }
-#else
+    if (is_cur_buf_hbd(xd)) {
+      aom_highbd_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst,
+                                dst_stride, xd->bd);
+    } else {
       aom_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst,
                          dst_stride);
-#endif
-      wht_fwd_txfm(src_diff, bw, coeff, tx_size, is_cur_buf_hbd(xd));
-      intra_cost = aom_satd(coeff, pix_num);
-    } else {
-      int64_t intra_sse;
-#if CONFIG_AV1_HIGHBITDEPTH
-      if (is_cur_buf_hbd(xd)) {
-        intra_sse =
-            aom_highbd_sse(xd->cur_buf->y_buffer + mb_y_offset,
-                           xd->cur_buf->y_stride, predictor, bw, bw, bh);
-      } else {
-        intra_sse = aom_sse(xd->cur_buf->y_buffer + mb_y_offset,
-                            xd->cur_buf->y_stride, predictor, bw, bw, bh);
-      }
-#else
-      intra_sse = aom_sse(xd->cur_buf->y_buffer + mb_y_offset,
-                          xd->cur_buf->y_stride, predictor, bw, bw, bh);
-#endif
-      intra_cost = ROUND_POWER_OF_TWO(intra_sse, (xd->bd - 8) * 2);
     }
+#else
+    aom_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst, dst_stride);
+#endif
+    wht_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd, is_cur_buf_hbd(xd));
+    intra_cost = aom_satd(coeff, pix_num);
     intra_cost += qstep_cur_noise;
 
     if (intra_cost < best_intra_cost) best_intra_cost = intra_cost;
@@ -245,10 +215,7 @@
 
     const int16_t qstep_ref =
         ROUND_POWER_OF_TWO(av1_ac_quant_QTX(q_ref, 0, xd->bd), xd->bd - 8);
-    const int qstep_ref_noise =
-        use_satd
-            ? ((int)qstep_ref * pix_num + 16) / (4 * 8)
-            : ((int)qstep_ref * (int)qstep_ref * pix_num + 384) / (12 * 64);
+    const int qstep_ref_noise = ((int)qstep_ref * pix_num + 16) / (4 * 8);
     int mb_y_offset_ref =
         mi_row * MI_SIZE * ref_frame[rf_idx]->y_stride + mi_col * MI_SIZE;
 
@@ -261,48 +228,31 @@
     WarpTypesAllowed warp_types;
     memset(&warp_types, 0, sizeof(WarpTypesAllowed));
 
-    av1_build_inter_predictor(ref_frame[rf_idx]->y_buffer + mb_y_offset_ref,
-                              ref_frame[rf_idx]->y_stride, &predictor[0], bw,
-                              &x->best_mv.as_mv, sf, bw, bh, &conv_params,
-                              kernel, &warp_types, mi_col * MI_SIZE,
-                              mi_row * MI_SIZE, 0, 0, MV_PRECISION_Q3,
-                              mi_col * MI_SIZE, mi_row * MI_SIZE, xd, 0);
-    if (use_satd) {
+    av1_build_inter_predictor(
+        ref_frame[rf_idx]->y_buffer + mb_y_offset_ref,
+        ref_frame[rf_idx]->y_stride, dst_buffer, dst_buffer_stride,
+        &x->best_mv.as_mv, sf, bw, bh, &conv_params, kernel, &warp_types,
+        mi_col * MI_SIZE, mi_row * MI_SIZE, 0, 0, MV_PRECISION_Q3,
+        mi_col * MI_SIZE, mi_row * MI_SIZE, xd, 0);
+
 #if CONFIG_AV1_HIGHBITDEPTH
-      if (is_cur_buf_hbd(xd)) {
-        aom_highbd_subtract_block(bh, bw, src_diff, bw,
-                                  xd->cur_buf->y_buffer + mb_y_offset,
-                                  xd->cur_buf->y_stride, predictor, bw, xd->bd);
-      } else {
-        aom_subtract_block(bh, bw, src_diff, bw,
-                           xd->cur_buf->y_buffer + mb_y_offset,
-                           xd->cur_buf->y_stride, predictor, bw);
-      }
-#else
+    if (is_cur_buf_hbd(xd)) {
+      aom_highbd_subtract_block(
+          bh, bw, src_diff, bw, xd->cur_buf->y_buffer + mb_y_offset,
+          xd->cur_buf->y_stride, dst_buffer, dst_buffer_stride, xd->bd);
+    } else {
       aom_subtract_block(bh, bw, src_diff, bw,
                          xd->cur_buf->y_buffer + mb_y_offset,
-                         xd->cur_buf->y_stride, predictor, bw);
-#endif
-      wht_fwd_txfm(src_diff, bw, coeff, tx_size, is_cur_buf_hbd(xd));
-
-      inter_cost = aom_satd(coeff, pix_num);
-    } else {
-      int64_t inter_sse;
-#if CONFIG_AV1_HIGHBITDEPTH
-      if (is_cur_buf_hbd(xd)) {
-        inter_sse =
-            aom_highbd_sse(xd->cur_buf->y_buffer + mb_y_offset,
-                           xd->cur_buf->y_stride, predictor, bw, bw, bh);
-      } else {
-        inter_sse = aom_sse(xd->cur_buf->y_buffer + mb_y_offset,
-                            xd->cur_buf->y_stride, predictor, bw, bw, bh);
-      }
-#else
-      inter_sse = aom_sse(xd->cur_buf->y_buffer + mb_y_offset,
-                          xd->cur_buf->y_stride, predictor, bw, bw, bh);
-#endif
-      inter_cost = ROUND_POWER_OF_TWO(inter_sse, (xd->bd - 8) * 2);
+                         xd->cur_buf->y_stride, dst_buffer, dst_buffer_stride);
     }
+#else
+    aom_subtract_block(bh, bw, src_diff, bw,
+                       xd->cur_buf->y_buffer + mb_y_offset,
+                       xd->cur_buf->y_stride, dst_buffer, dst_buffer_stride);
+#endif
+
+    wht_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd, is_cur_buf_hbd(xd));
+    inter_cost = aom_satd(coeff, pix_num);
     inter_cost_weighted = inter_cost + qstep_ref_noise;
 
     if (inter_cost_weighted < best_inter_cost_weighted) {
@@ -616,7 +566,7 @@
       xd->mb_to_left_edge = -((mi_col * MI_SIZE) * 8);
       xd->mb_to_right_edge = ((cm->mi_cols - mi_width - mi_col) * MI_SIZE) * 8;
       mode_estimation(cpi, x, xd, &sf, frame_idx, src_diff, coeff, qcoeff,
-                      dqcoeff, 1, mi_row, mi_col, bsize, tx_size, ref_frame,
+                      dqcoeff, mi_row, mi_col, bsize, tx_size, ref_frame,
                       predictor, &recon_error, &sse, &tpl_stats);
 
       // Motion flow dependency dispenser.
@@ -959,7 +909,8 @@
           aom_subtract_block(bh, bw, src_diff, bw, src_buf, src_stride, dst_buf,
                              dst_stride);
 #endif
-          wht_fwd_txfm(src_diff, bw, coeff, tx_size, is_cur_buf_hbd(xd));
+          wht_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd,
+                       is_cur_buf_hbd(xd));
 
           intra_cost = aom_satd(coeff, pix_num);
         } else {
@@ -1010,7 +961,7 @@
         aom_subtract_block(bh, bw, src_diff, bw, src->y_buffer + mb_y_offset,
                            src->y_stride, predictor, bw);
 #endif
-        wht_fwd_txfm(src_diff, bw, coeff, tx_size, is_cur_buf_hbd(xd));
+        wht_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd, is_cur_buf_hbd(xd));
         inter_cost = aom_satd(coeff, pix_num);
       } else {
         int64_t sse;