Refactor transform coding path in tpl
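Consolidate the transform coding and rate-distortion estimation
operations for tpl model building into a single txfm_quant_rdcost()
helper (residual subtraction, forward transform, quantization error,
rate estimate and inverse transform). Export av1_subtract_block() from
encodemb.c so that tpl_model.c can reuse it instead of repeating the
CONFIG_AV1_HIGHBITDEPTH dispatch at every call site.
Note: the inter search loop now calls the helper for every candidate
reference frame, so quantization, rate estimation and the inverse
transform run per candidate rather than only for the best one.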
Change-Id: I0bb0dfe912f11bc024ac22bcb633581e454dfa93
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index e528842..fb4fab5 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -35,10 +35,10 @@
#include "av1/encoder/rd.h"
#include "av1/encoder/rdopt.h"
-static void subtract_block(const MACROBLOCKD *xd, int rows, int cols,
- int16_t *diff, ptrdiff_t diff_stride,
- const uint8_t *src8, ptrdiff_t src_stride,
- const uint8_t *pred8, ptrdiff_t pred_stride) {
+void av1_subtract_block(const MACROBLOCKD *xd, int rows, int cols,
+ int16_t *diff, ptrdiff_t diff_stride,
+ const uint8_t *src8, ptrdiff_t src_stride,
+ const uint8_t *pred8, ptrdiff_t pred_stride) {
assert(rows >= 4 && cols >= 4);
#if CONFIG_AV1_HIGHBITDEPTH
if (is_cur_buf_hbd(xd)) {
@@ -68,8 +68,8 @@
&p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
int16_t *src_diff =
&p->src_diff[(blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]];
- subtract_block(xd, tx1d_height, tx1d_width, src_diff, diff_stride, src,
- src_stride, dst, dst_stride);
+ av1_subtract_block(xd, tx1d_height, tx1d_width, src_diff, diff_stride, src,
+ src_stride, dst, dst_stride);
}
void av1_subtract_plane(MACROBLOCK *x, BLOCK_SIZE bsize, int plane) {
@@ -82,8 +82,8 @@
const int bh = block_size_high[plane_bsize];
const MACROBLOCKD *xd = &x->e_mbd;
- subtract_block(xd, bh, bw, p->src_diff, bw, p->src.buf, p->src.stride,
- pd->dst.buf, pd->dst.stride);
+ av1_subtract_block(xd, bh, bw, p->src_diff, bw, p->src.buf, p->src.stride,
+ pd->dst.buf, pd->dst.stride);
}
int av1_optimize_b(const struct AV1_COMP *cpi, MACROBLOCK *mb, int plane,
diff --git a/av1/encoder/encodemb.h b/av1/encoder/encodemb.h
index d4394cf..efb2314 100644
--- a/av1/encoder/encodemb.h
+++ b/av1/encoder/encodemb.h
@@ -68,6 +68,11 @@
int block, TX_SIZE tx_size, TX_TYPE tx_type,
const TXB_CTX *const txb_ctx, int fast_mode, int *rate_cost);
+void av1_subtract_block(const MACROBLOCKD *xd, int rows, int cols,
+ int16_t *diff, ptrdiff_t diff_stride,
+ const uint8_t *src8, ptrdiff_t src_stride,
+ const uint8_t *pred8, ptrdiff_t pred_stride);
+
void av1_subtract_txb(MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize,
int blk_col, int blk_row, TX_SIZE tx_size);
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index aaef6de..abfa949 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -81,6 +81,28 @@
return (rate_cost << AV1_PROB_COST_SHIFT);
}
+static void txfm_quant_rdcost(MACROBLOCK *x, int16_t *src_diff, int diff_stride,
+ uint8_t *src, int src_stride, uint8_t *dst,
+ int dst_stride, tran_low_t *coeff,
+ tran_low_t *qcoeff, tran_low_t *dqcoeff, int bw,
+ int bh, TX_SIZE tx_size, int *rate_cost,
+ int64_t *recon_error, int64_t *sse) {
+ const MACROBLOCKD *xd = &x->e_mbd;
+ uint16_t eob;
+ av1_subtract_block(xd, bh, bw, src_diff, diff_stride, src, src_stride, dst,
+ dst_stride);
+ wht_fwd_txfm(src_diff, diff_stride, coeff, tx_size, xd->bd,
+ is_cur_buf_hbd(xd));
+
+ get_quantize_error(x, 0, coeff, qcoeff, dqcoeff, tx_size, &eob, recon_error,
+ sse);
+
+ *rate_cost = rate_estimator(qcoeff, eob, tx_size);
+
+ av1_inverse_transform_block(xd, dqcoeff, 0, DCT_DCT, tx_size, dst, dst_stride,
+ eob, 0);
+}
+
static uint32_t motion_estimation(AV1_COMP *cpi, MACROBLOCK *x,
uint8_t *cur_frame_buf,
uint8_t *ref_frame_buf, int stride,
@@ -218,17 +240,8 @@
FILTER_INTRA_MODES, dst_buffer, dst_buffer_stride,
dst, dst_stride, 0, 0, 0);
-#if CONFIG_AV1_HIGHBITDEPTH
- if (is_cur_buf_hbd(xd)) {
- aom_highbd_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst,
- dst_stride, xd->bd);
- } else {
- aom_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst,
- dst_stride);
- }
-#else
- aom_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst, dst_stride);
-#endif
+ av1_subtract_block(xd, bh, bw, src_diff, bw, src, src_stride, dst,
+ dst_stride);
wht_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd, is_cur_buf_hbd(xd));
intra_cost = aom_satd(coeff, pix_num);
@@ -253,6 +266,8 @@
if (ref_frame[rf_idx] == NULL) continue;
if (src_ref_frame[rf_idx] == NULL) continue;
+ int rate_cost;
+ int64_t distortion, tsse;
const YV12_BUFFER_CONFIG *ref_frame_ptr = src_ref_frame[rf_idx];
int ref_mb_offset =
mi_row * MI_SIZE * ref_frame_ptr->y_stride + mi_col * MI_SIZE;
@@ -274,32 +289,19 @@
av1_build_inter_predictor(predictor, bw, &x->best_mv.as_mv,
&inter_pred_params);
-#if CONFIG_AV1_HIGHBITDEPTH
- if (is_cur_buf_hbd(xd)) {
- aom_highbd_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
- predictor, bw, xd->bd);
- } else {
- aom_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
- predictor, bw);
- }
-#else
- aom_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
- predictor, bw);
-#endif
+ txfm_quant_rdcost(x, src_diff, bw, src_mb_buffer, src_stride, predictor, bw,
+ coeff, qcoeff, dqcoeff, bw, bh, tx_size, &rate_cost,
+ &distortion, &tsse);
- wht_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd, is_cur_buf_hbd(xd));
inter_cost = aom_satd(coeff, pix_num);
if (inter_cost < best_inter_cost) {
- uint16_t eob;
best_rf_idx = rf_idx;
best_inter_cost = inter_cost;
best_mv.as_int = x->best_mv.as_int;
- get_quantize_error(x, 0, coeff, qcoeff, dqcoeff, tx_size, &eob,
- recon_error, sse);
- int rate_cost = rate_estimator(qcoeff, eob, tx_size);
tpl_stats->srcrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;
-
+ *recon_error = distortion;
+ *sse = tsse;
if (best_inter_cost < best_intra_cost) best_mode = NEWMV;
}
}
@@ -335,29 +337,10 @@
dst_buffer, dst_buffer_stride, 0, 0, 0);
}
-#if CONFIG_AV1_HIGHBITDEPTH
- if (is_cur_buf_hbd(xd)) {
- aom_highbd_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
- dst_buffer, dst_buffer_stride, xd->bd);
- } else {
- aom_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
- dst_buffer, dst_buffer_stride);
- }
-#else
- aom_subtract_block(bh, bw, src_diff, bw, src_mb_buffer, src_stride,
- dst_buffer, dst_buffer_stride);
-#endif
- wht_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd, is_cur_buf_hbd(xd));
-
- uint16_t eob;
-
- get_quantize_error(x, 0, coeff, qcoeff, dqcoeff, tx_size, &eob, recon_error,
- sse);
-
- int rate_cost = rate_estimator(qcoeff, eob, tx_size);
-
- av1_inverse_transform_block(xd, dqcoeff, 0, DCT_DCT, tx_size, dst_buffer,
- dst_buffer_stride, eob, 0);
+ int rate_cost;
+ txfm_quant_rdcost(x, src_diff, bw, src_mb_buffer, src_stride, dst_buffer,
+ dst_buffer_stride, coeff, qcoeff, dqcoeff, bw, bh, tx_size,
+ &rate_cost, recon_error, sse);
tpl_stats->recrf_dist = *recon_error << (TPL_DEP_COST_SCALE_LOG2);
tpl_stats->recrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;