ncobmc-adapt-weight: add interpolation mode search functions
Change-Id: I5370e38f6fe00f467e1945bc46866adea9422b22
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 96a496d..6220cfd 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -8439,6 +8439,9 @@
#if CONFIG_GLOBAL_MOTION
0, xd->global_motion,
#endif // CONFIG_GLOBAL_MOTION
+#if CONFIG_WARPED_MOTION
+ xd,
+#endif
mi);
#else
last_motion_mode_allowed = motion_mode_allowed(
@@ -12587,4 +12590,261 @@
}
}
#endif // CONFIG_NCOBMC
+
+#if CONFIG_NCOBMC_ADAPT_WEIGHT
+void av1_check_ncobmc_adapt_weight_rd(const struct AV1_COMP *cpi,
+ struct macroblock *x, int mi_row,
+ int mi_col) {
+ const AV1_COMMON *const cm = &cpi->common;
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+ BLOCK_SIZE bsize = mbmi->sb_type;
+#if CONFIG_VAR_TX
+ const int n4 = bsize_to_num_blk(bsize);
+ uint8_t backup_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
+#endif
+ MB_MODE_INFO backup_mbmi;
+ int plane, ref, skip_blk, backup_skip;
+ RD_STATS rd_stats_y, rd_stats_uv, rd_stats_y2, rd_stats_uv2;
+ int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+ int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ int64_t prev_rd, naw_rd; // ncobmc_adapt_weight_rd
+
+ // Recompute the rd for the motion mode decided in rd loop
+ if (mbmi->motion_mode == SIMPLE_TRANSLATION ||
+ mbmi->motion_mode == OBMC_CAUSAL) {
+ set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+ for (ref = 0; ref < 1 + has_second_ref(mbmi); ++ref) {
+ YV12_BUFFER_CONFIG *cfg = get_ref_frame_buffer(cpi, mbmi->ref_frame[ref]);
+ assert(cfg != NULL);
+ av1_setup_pre_planes(xd, ref, cfg, mi_row, mi_col,
+ &xd->block_refs[ref]->sf);
+ }
+ av1_setup_dst_planes(xd->plane, bsize, get_frame_new_buffer(cm), mi_row,
+ mi_col);
+
+ av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, NULL, bsize);
+ if (mbmi->motion_mode == OBMC_CAUSAL) {
+ av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
+ }
+
+ av1_subtract_plane(x, bsize, 0);
+
+#if CONFIG_VAR_TX
+ if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
+ select_tx_type_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
+ } else {
+ int idx, idy;
+ super_block_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
+ for (idy = 0; idy < xd->n8_h; ++idy)
+ for (idx = 0; idx < xd->n8_w; ++idx)
+ mbmi->inter_tx_size[idy][idx] = mbmi->tx_size;
+ memset(x->blk_skip[0], rd_stats_y2.skip,
+ sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
+ }
+ inter_block_uvrd(cpi, x, &rd_stats_uv2, bsize, INT64_MAX);
+#else
+ super_block_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
+ super_block_uvrd(cpi, x, &rd_stats_uv2, bsize, INT64_MAX);
+#endif
+ }
+
+ if (rd_stats_y2.skip && rd_stats_uv2.skip) {
+ rd_stats_y2.rate = rate_skip1;
+ rd_stats_uv2.rate = 0;
+ rd_stats_y2.dist = rd_stats_y2.sse;
+ rd_stats_uv2.dist = rd_stats_uv2.sse;
+ skip_blk = 1;
+ } else if (RDCOST(x->rdmult,
+ (rd_stats_y2.rate + rd_stats_uv2.rate + rate_skip0),
+ (rd_stats_y2.dist + rd_stats_uv2.dist)) >
+ RDCOST(x->rdmult, rate_skip1,
+ (rd_stats_y2.sse + rd_stats_uv2.sse))) {
+ rd_stats_y2.rate = rate_skip1;
+ rd_stats_uv2.rate = 0;
+ rd_stats_y2.dist = rd_stats_y2.sse;
+ rd_stats_uv2.dist = rd_stats_uv2.sse;
+ skip_blk = 1;
+ } else {
+ rd_stats_y2.rate += rate_skip0;
+ skip_blk = 0;
+ }
+
+ backup_mbmi = *mbmi;
+ backup_skip = skip_blk;
+#if CONFIG_VAR_TX
+ memcpy(backup_blk_skip, x->blk_skip[0], sizeof(backup_blk_skip[0]) * n4);
+#endif
+ prev_rd = RDCOST(x->rdmult, (rd_stats_y2.rate + rd_stats_uv2.rate),
+ (rd_stats_y2.dist + rd_stats_uv2.dist));
+ prev_rd +=
+ RDCOST(x->rdmult, x->motion_mode_cost[bsize][mbmi->motion_mode], 0);
+
+ // Compute the rd cost for ncobmc adaptive weight
+ mbmi->motion_mode = NCOBMC_ADAPT_WEIGHT;
+
+ for (plane = 0; plane < MAX_MB_PLANE; ++plane) {
+ get_pred_from_intrpl_buf(xd, mi_row, mi_col, bsize, plane);
+ }
+
+ av1_subtract_plane(x, bsize, 0);
+
+#if CONFIG_VAR_TX
+ if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
+ select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
+ } else {
+ int idx, idy;
+ super_block_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
+ for (idy = 0; idy < xd->n8_h; ++idy)
+ for (idx = 0; idx < xd->n8_w; ++idx)
+ mbmi->inter_tx_size[idy][idx] = mbmi->tx_size;
+ memset(x->blk_skip[0], rd_stats_y.skip,
+ sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
+ }
+ inter_block_uvrd(cpi, x, &rd_stats_uv, bsize, INT64_MAX);
+#else
+ super_block_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
+ super_block_uvrd(cpi, x, &rd_stats_uv, bsize, INT64_MAX);
+#endif
+ assert(rd_stats_y.rate != INT_MAX && rd_stats_uv.rate != INT_MAX);
+
+ if (rd_stats_y.skip && rd_stats_uv.skip) {
+ rd_stats_y.rate = rate_skip1;
+ rd_stats_uv.rate = 0;
+ rd_stats_y.dist = rd_stats_y.sse;
+ rd_stats_uv.dist = rd_stats_uv.sse;
+ skip_blk = 1;
+ } else if (RDCOST(x->rdmult,
+ (rd_stats_y.rate + rd_stats_uv.rate + rate_skip0),
+ (rd_stats_y.dist + rd_stats_uv.dist)) >
+ RDCOST(x->rdmult, rate_skip1,
+ (rd_stats_y.sse + rd_stats_uv.sse))) {
+ rd_stats_y.rate = rate_skip1;
+ rd_stats_uv.rate = 0;
+ rd_stats_y.dist = rd_stats_y.sse;
+ rd_stats_uv.dist = rd_stats_uv.sse;
+ skip_blk = 1;
+ } else {
+ rd_stats_y.rate += rate_skip0;
+ skip_blk = 0;
+ }
+ naw_rd = RDCOST(x->rdmult, (rd_stats_y.rate + rd_stats_uv.rate),
+ (rd_stats_y.dist + rd_stats_uv.dist));
+ naw_rd += RDCOST(x->rdmult, x->motion_mode_cost[bsize][mbmi->motion_mode], 0);
+
+ // Calculate the ncobmc mode costs
+ {
+ ADAPT_OVERLAP_BLOCK aob = adapt_overlap_block_lookup[bsize];
+ naw_rd +=
+ RDCOST(x->rdmult, x->ncobmc_mode_cost[aob][mbmi->ncobmc_mode[0]], 0);
+ if (mi_size_wide[bsize] != mi_size_high[bsize])
+ naw_rd +=
+ RDCOST(x->rdmult, x->ncobmc_mode_cost[aob][mbmi->ncobmc_mode[1]], 0);
+ }
+
+ if (prev_rd > naw_rd) {
+ x->skip = skip_blk;
+ } else {
+ *mbmi = backup_mbmi;
+ x->skip = backup_skip;
+#if CONFIG_VAR_TX
+ memcpy(x->blk_skip[0], backup_blk_skip, sizeof(backup_blk_skip[0]) * n4);
+#endif
+ }
+}
+
+int64_t get_ncobmc_error(MACROBLOCKD *xd, int pxl_row, int pxl_col,
+ BLOCK_SIZE bsize, int plane, struct buf_2d *src) {
+ const int wide = AOMMIN(mi_size_wide[bsize] * MI_SIZE,
+ (xd->sb_mi_bd.mi_col_end + 1) * MI_SIZE - pxl_col);
+ const int high = AOMMIN(mi_size_high[bsize] * MI_SIZE,
+ (xd->sb_mi_bd.mi_row_end + 1) * MI_SIZE - pxl_row);
+ const int ss_x = xd->plane[plane].subsampling_x;
+ const int ss_y = xd->plane[plane].subsampling_y;
+ int row_offset = (pxl_row - xd->sb_mi_bd.mi_row_begin * MI_SIZE) >> ss_y;
+ int col_offset = (pxl_col - xd->sb_mi_bd.mi_col_begin * MI_SIZE) >> ss_x;
+ int dst_stride = xd->ncobmc_pred_buf_stride[plane];
+ int dst_offset = row_offset * dst_stride + col_offset;
+ int src_stride = src->stride;
+
+ int r, c;
+ int64_t tmp, error = 0;
+
+ for (r = 0; r < (high >> ss_y); ++r) {
+ for (c = 0; c < (wide >> ss_x); ++c) {
+ tmp = xd->ncobmc_pred_buf[plane][r * dst_stride + c + dst_offset] -
+ src->buf[r * src_stride + c];
+ error += tmp * tmp;
+ }
+ }
+ return error;
+}
+
+int get_ncobmc_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
+ MACROBLOCKD *xd, int mi_row, int mi_col, int bsize) {
+ const AV1_COMMON *const cm = &cpi->common;
+#if CONFIG_HIGHBITDEPTH
+ DECLARE_ALIGNED(16, uint8_t, tmp_buf_0[2 * MAX_MB_PLANE * MAX_SB_SQUARE]);
+ DECLARE_ALIGNED(16, uint8_t, tmp_buf_1[2 * MAX_MB_PLANE * MAX_SB_SQUARE]);
+ DECLARE_ALIGNED(16, uint8_t, tmp_buf_2[2 * MAX_MB_PLANE * MAX_SB_SQUARE]);
+ DECLARE_ALIGNED(16, uint8_t, tmp_buf_3[2 * MAX_MB_PLANE * MAX_SB_SQUARE]);
+#else
+ DECLARE_ALIGNED(16, uint8_t, tmp_buf_0[MAX_MB_PLANE * MAX_SB_SQUARE]);
+ DECLARE_ALIGNED(16, uint8_t, tmp_buf_1[MAX_MB_PLANE * MAX_SB_SQUARE]);
+ DECLARE_ALIGNED(16, uint8_t, tmp_buf_2[MAX_MB_PLANE * MAX_SB_SQUARE]);
+ DECLARE_ALIGNED(16, uint8_t, tmp_buf_3[MAX_MB_PLANE * MAX_SB_SQUARE]);
+#endif
+ uint8_t *pred_buf[4][MAX_MB_PLANE];
+
+ // TODO(weitinglin): stride size needs to be fixed for high-bit depth
+ int pred_stride[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
+
+ // target block in pxl
+ int pxl_row = mi_row << MI_SIZE_LOG2;
+ int pxl_col = mi_col << MI_SIZE_LOG2;
+ int64_t error, best_error = INT64_MAX;
+ int plane, tmp_mode, best_mode = 0;
+#if CONFIG_HIGHBITDEPTH
+ if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+ int len = sizeof(uint16_t);
+ ASSIGN_ALIGNED_PTRS_HBD(pred_buf[0], tmp_buf_0, MAX_SB_SQUARE, len);
+ ASSIGN_ALIGNED_PTRS_HBD(pred_buf[1], tmp_buf_0, MAX_SB_SQUARE, len);
+ ASSIGN_ALIGNED_PTRS_HBD(pred_buf[2], tmp_buf_0, MAX_SB_SQUARE, len);
+ ASSIGN_ALIGNED_PTRS_HBD(pred_buf[3], tmp_buf_0, MAX_SB_SQUARE, len);
+ } else {
+#endif // CONFIG_HIGHBITDEPTH
+ ASSIGN_ALIGNED_PTRS(pred_buf[0], tmp_buf_0, MAX_SB_SQUARE);
+ ASSIGN_ALIGNED_PTRS(pred_buf[1], tmp_buf_1, MAX_SB_SQUARE);
+ ASSIGN_ALIGNED_PTRS(pred_buf[2], tmp_buf_2, MAX_SB_SQUARE);
+ ASSIGN_ALIGNED_PTRS(pred_buf[3], tmp_buf_3, MAX_SB_SQUARE);
+#if CONFIG_HIGHBITDEPTH
+ }
+#endif
+
+ av1_get_ext_blk_preds(cm, xd, bsize, mi_row, mi_col, pred_buf, pred_stride);
+ av1_get_ori_blk_pred(cm, xd, bsize, mi_row, mi_col, pred_buf[3], pred_stride);
+
+ for (tmp_mode = 0; tmp_mode < MAX_NCOBMC_MODES; ++tmp_mode) {
+ error = 0;
+ for (plane = 0; plane < MAX_MB_PLANE; ++plane) {
+ build_ncobmc_intrpl_pred(cm, xd, plane, pxl_row, pxl_col, bsize, pred_buf,
+ pred_stride, tmp_mode);
+ error += get_ncobmc_error(xd, pxl_row, pxl_col, bsize, plane,
+ &x->plane[plane].src);
+ }
+ if (error < best_error) {
+ best_mode = tmp_mode;
+ best_error = error;
+ }
+ }
+
+ for (plane = 0; plane < MAX_MB_PLANE; ++plane) {
+ build_ncobmc_intrpl_pred(cm, xd, plane, pxl_row, pxl_col, bsize, pred_buf,
+ pred_stride, best_mode);
+ }
+
+ return best_mode;
+}
+
+#endif // CONFIG_NCOBMC_ADAPT_WEIGHT
#endif // CONFIG_MOTION_VAR
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 4923952..bcad8f8 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -144,4 +144,20 @@
const MACROBLOCKD *xd, BLOCK_SIZE bsize, int plane,
TX_SIZE tx_size, TX_TYPE tx_type);
+#if CONFIG_NCOBMC_ADAPT_WEIGHT
+void av1_check_ncobmc_adapt_weight_rd(const struct AV1_COMP *cpi,
+ struct macroblock *x, int mi_row,
+ int mi_col);
+
+int get_ncobmc_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
+ MACROBLOCKD *xd, int mi_row, int mi_col, int bsize);
+
+void av1_setup_src_planes_pxl(MACROBLOCK *x, const YV12_BUFFER_CONFIG *src,
+ int pxl_row, int pxl_col);
+
+void rebuild_ncobmc_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
+ MACROBLOCKD *xd, int mi_row, int mi_col, int bsize,
+ int xd_mi_offset, NCOBMC_MODE best_mode, int rebuild);
+#endif
+
#endif // AV1_ENCODER_RDOPT_H_