Add comments in handle_inter_intra_mode
Added comments and did minor code refactoring in
handle_inter_intra_mode
Change-Id: I20164107c45a8fbccc81b096e7d9129b529554fb
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index badfbbf..d41d971 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4358,48 +4358,48 @@
int *tmp_rate2, const BUFFER_SET *orig_dst) {
const int try_smooth_interintra = cpi->oxcf.enable_smooth_interintra &&
!cpi->sf.inter_sf.disable_smooth_interintra;
- const int is_wedge_used = av1_is_wedge_used(bsize);
const int try_wedge_interintra =
- is_wedge_used && enable_wedge_interintra_search(x, cpi);
+ av1_is_wedge_used(bsize) && enable_wedge_interintra_search(x, cpi);
if (!try_smooth_interintra && !try_wedge_interintra) return -1;
const AV1_COMMON *const cm = &cpi->common;
MACROBLOCKD *xd = &x->e_mbd;
int64_t rd = INT64_MAX;
- int rmode, rate_sum;
- int64_t dist_sum;
- int tmp_rate_mv = 0;
- int tmp_skip_txfm_sb;
const int bw = block_size_wide[bsize];
- int64_t tmp_skip_sse_sb;
DECLARE_ALIGNED(16, uint8_t, tmp_buf_[2 * MAX_INTERINTRA_SB_SQUARE]);
DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_INTERINTRA_SB_SQUARE]);
uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_);
uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
const int *const interintra_mode_cost =
x->interintra_mode_cost[size_group_lookup[bsize]];
- const int_mv mv0 = mbmi->mv[0];
+ const int mi_row = xd->mi_row;
+ const int mi_col = xd->mi_col;
+
+ // Single reference inter prediction
mbmi->ref_frame[1] = NONE_FRAME;
xd->plane[0].dst.buf = tmp_buf;
xd->plane[0].dst.stride = bw;
- const int mi_row = xd->mi_row;
- const int mi_col = xd->mi_col;
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
const int num_planes = av1_num_planes(cm);
+
+ // Restore the buffers for intra prediction
restore_dst_buf(xd, *orig_dst, num_planes);
mbmi->ref_frame[1] = INTRA_FRAME;
INTERINTRA_MODE best_interintra_mode =
args->inter_intra_mode[mbmi->ref_frame[0]];
+ // Compute smooth_interintra
int64_t best_interintra_rd_nowedge = INT64_MAX;
if (try_smooth_interintra) {
mbmi->use_wedge_interintra = 0;
- INTERINTRA_MODE cur_mode = 0;
+ int interintra_mode_reuse = 1;
if (cpi->sf.inter_sf.reuse_inter_intra_mode == 0 ||
best_interintra_mode == INTERINTRA_MODES) {
+ interintra_mode_reuse = 0;
int64_t best_interintra_rd = INT64_MAX;
- for (cur_mode = 0; cur_mode < INTERINTRA_MODES; ++cur_mode) {
+ for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
+ ++cur_mode) {
if ((!cpi->oxcf.enable_smooth_intra ||
cpi->sf.intra_sf.disable_smooth_intra) &&
cur_mode == II_SMOOTH_PRED)
@@ -4414,14 +4414,16 @@
assert(IMPLIES(!cpi->oxcf.enable_smooth_interintra ||
cpi->sf.inter_sf.disable_smooth_interintra,
best_interintra_mode != II_SMOOTH_PRED));
- rmode = interintra_mode_cost[best_interintra_mode];
- if (cur_mode == 0 || best_interintra_mode != INTERINTRA_MODES - 1) {
+ int rmode = interintra_mode_cost[best_interintra_mode];
+ // Recompute prediction if required
+ if (interintra_mode_reuse || best_interintra_mode != INTERINTRA_MODES - 1) {
mbmi->interintra_mode = best_interintra_mode;
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
}
+ // Compute rd cost for best smooth_interintra
RD_STATS rd_stats;
int total_mode_rate;
const int64_t rd_thresh = compute_total_rate_and_rd_thresh(
@@ -4433,6 +4435,7 @@
return -1;
}
best_interintra_rd_nowedge = rd;
+ // Return early if best_interintra_rd_nowedge not good enough
if (ref_best_rd < INT64_MAX &&
(best_interintra_rd_nowedge >> INTER_INTRA_RD_THRESH_SHIFT) *
INTER_INTRA_RD_THRESH_SCALE >
@@ -4441,6 +4444,7 @@
}
}
+ // Compute wedge interintra
int64_t best_interintra_rd_wedge = INT64_MAX;
if (try_wedge_interintra) {
mbmi->use_wedge_interintra = 1;
@@ -4463,9 +4467,10 @@
best_interintra_mode = INTERINTRA_MODES - 1;
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
+ // Pick wedge mask based on INTERINTRA_MODES - 1
best_interintra_rd_wedge =
pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
-
+ // Find the best interintra mode for the chosen wedge mask
for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
++cur_mode) {
compute_best_interintra_mode(
@@ -4476,11 +4481,13 @@
args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
mbmi->interintra_mode = best_interintra_mode;
+ // Recompute prediction if required
if (best_interintra_mode != INTERINTRA_MODES - 1) {
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
}
} else {
+ // Pick wedge mask for the best interintra mode (reused)
mbmi->interintra_mode = best_interintra_mode;
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
@@ -4488,6 +4495,7 @@
pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
}
} else {
+ // Pick wedge mask for the best interintra mode from smooth_interintra
best_interintra_rd_wedge =
pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
}
@@ -4498,14 +4506,17 @@
x->wedge_interintra_cost[bsize][1];
best_interintra_rd_wedge += RDCOST(x->rdmult, rate_overhead + *rate_mv, 0);
- int_mv tmp_mv;
+ const int_mv mv0 = mbmi->mv[0];
+ int_mv tmp_mv = mv0;
rd = INT64_MAX;
- // Refine motion vector.
+ int tmp_rate_mv = 0;
+ // Refine motion vector for NEWMV case.
if (have_newmv_in_inter_mode(mbmi->mode)) {
+ int rate_sum, skip_txfm_sb;
+ int64_t dist_sum, skip_sse_sb;
// get negative of mask
const uint8_t *mask =
av1_get_contiguous_soft_mask(mbmi->interintra_wedge_index, 1, bsize);
- tmp_mv = mbmi->mv[0];
compound_single_motion_search(cpi, x, bsize, &tmp_mv.as_mv, intrapred,
mask, bw, &tmp_rate_mv, 0);
if (mbmi->mv[0].as_int != tmp_mv.as_int) {
@@ -4513,8 +4524,8 @@
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
- cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
- &tmp_skip_sse_sb, NULL, NULL, NULL);
+ cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &skip_txfm_sb,
+ &skip_sse_sb, NULL, NULL, NULL);
rd =
RDCOST(x->rdmult, tmp_rate_mv + rate_overhead + rate_sum, dist_sum);
}