JNT_COMP: 3. rd select the best weight
Select the best compound_idx in rd.
The rate/cost for compound_idx and its contexts will be added in patch 4.
However, there is currently a bug if we don't encode one more time using the
selected compound_idx. It remains an issue to be solved in the future.
Change-Id: I5e1ba51da2b6ab5bacd8aba752dda43bd2257014
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 042448e..ce4c03d 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5795,43 +5795,48 @@
static void jnt_comp_weight_assign(const AV1_COMMON *cm,
const MB_MODE_INFO *mbmi, int order_idx,
uint8_t *second_pred) {
- int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
- int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
- int bck_frame_index = 0, fwd_frame_index = 0;
- int cur_frame_index = cm->cur_frame->cur_frame_offset;
-
- if (bck_idx >= 0) {
- bck_frame_index = cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
- }
-
- if (fwd_idx >= 0) {
- fwd_frame_index = cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
- }
-
- const double fwd = abs(fwd_frame_index - cur_frame_index);
- const double bck = abs(cur_frame_index - bck_frame_index);
- int order;
- double ratio;
-
- if (COMPOUND_WEIGHT_MODE == DIST) {
- if (fwd > bck) {
- ratio = (bck != 0) ? fwd / bck : 5.0;
- order = 0;
- } else {
- ratio = (fwd != 0) ? bck / fwd : 5.0;
- order = 1;
- }
- int quant_dist_idx;
- for (quant_dist_idx = 0; quant_dist_idx < 4; ++quant_dist_idx) {
- if (ratio < quant_dist_category[quant_dist_idx]) break;
- }
- second_pred[4096] =
- quant_dist_lookup_table[order_idx][quant_dist_idx][order];
- second_pred[4097] =
- quant_dist_lookup_table[order_idx][quant_dist_idx][1 - order];
+ if (mbmi->compound_idx) {
+ second_pred[4096] = -1;
+ second_pred[4097] = -1;
} else {
- second_pred[4096] = (DIST_PRECISION >> 1);
- second_pred[4097] = (DIST_PRECISION >> 1);
+ int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
+ int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
+ int bck_frame_index = 0, fwd_frame_index = 0;
+ int cur_frame_index = cm->cur_frame->cur_frame_offset;
+
+ if (bck_idx >= 0) {
+ bck_frame_index = cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
+ }
+
+ if (fwd_idx >= 0) {
+ fwd_frame_index = cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
+ }
+
+ const double fwd = abs(fwd_frame_index - cur_frame_index);
+ const double bck = abs(cur_frame_index - bck_frame_index);
+ int order;
+ double ratio;
+
+ if (COMPOUND_WEIGHT_MODE == DIST) {
+ if (fwd > bck) {
+ ratio = (bck != 0) ? fwd / bck : 5.0;
+ order = 0;
+ } else {
+ ratio = (fwd != 0) ? bck / fwd : 5.0;
+ order = 1;
+ }
+ int quant_dist_idx;
+ for (quant_dist_idx = 0; quant_dist_idx < 4; ++quant_dist_idx) {
+ if (ratio < quant_dist_category[quant_dist_idx]) break;
+ }
+ second_pred[4096] =
+ quant_dist_lookup_table[order_idx][quant_dist_idx][order];
+ second_pred[4097] =
+ quant_dist_lookup_table[order_idx][quant_dist_idx][1 - order];
+ } else {
+ second_pred[4096] = (DIST_PRECISION >> 1);
+ second_pred[4097] = (DIST_PRECISION >> 1);
+ }
}
}
#endif // CONFIG_JNT_COMP
@@ -10217,6 +10222,130 @@
}
}
}
+#if CONFIG_JNT_COMP
+ {
+ int cum_rate = rate2;
+ MB_MODE_INFO backup_mbmi = *mbmi;
+
+ int_mv backup_frame_mv[MB_MODE_COUNT][TOTAL_REFS_PER_FRAME];
+ int_mv backup_single_newmv[TOTAL_REFS_PER_FRAME];
+ int backup_single_newmv_rate[TOTAL_REFS_PER_FRAME];
+ int64_t backup_modelled_rd[MB_MODE_COUNT][TOTAL_REFS_PER_FRAME];
+
+ memcpy(backup_frame_mv, frame_mv, sizeof(frame_mv));
+ memcpy(backup_single_newmv, single_newmv, sizeof(single_newmv));
+ memcpy(backup_single_newmv_rate, single_newmv_rate,
+ sizeof(single_newmv_rate));
+ memcpy(backup_modelled_rd, modelled_rd, sizeof(modelled_rd));
+
+ InterpFilters backup_interp_filters = mbmi->interp_filters;
+
+ for (int comp_idx = 0; comp_idx < 1 + has_second_ref(mbmi);
+ ++comp_idx) {
+ RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
+ av1_init_rd_stats(&rd_stats);
+ av1_init_rd_stats(&rd_stats_y);
+ av1_init_rd_stats(&rd_stats_uv);
+ rd_stats.rate = cum_rate;
+
+ memcpy(frame_mv, backup_frame_mv, sizeof(frame_mv));
+ memcpy(single_newmv, backup_single_newmv, sizeof(single_newmv));
+ memcpy(single_newmv_rate, backup_single_newmv_rate,
+ sizeof(single_newmv_rate));
+ memcpy(modelled_rd, backup_modelled_rd, sizeof(modelled_rd));
+
+ mbmi->interp_filters = backup_interp_filters;
+
+ int dummy_disable_skip = 0;
+
+ // Point to variables that are maintained between loop iterations
+ args.single_newmv = single_newmv;
+ args.single_newmv_rate = single_newmv_rate;
+ args.modelled_rd = modelled_rd;
+ mbmi->compound_idx = comp_idx;
+
+ int64_t tmp_rd = handle_inter_mode(
+ cpi, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
+ &dummy_disable_skip, frame_mv, mi_row, mi_col, &args, best_rd);
+
+ if (tmp_rd < INT64_MAX) {
+ if (RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist) <
+ RDCOST(x->rdmult, 0, rd_stats.sse))
+ tmp_rd =
+ RDCOST(x->rdmult, rd_stats.rate + x->skip_cost[skip_ctx][0],
+ rd_stats.dist);
+ else
+ tmp_rd = RDCOST(x->rdmult,
+ rd_stats.rate + x->skip_cost[skip_ctx][1] -
+ rd_stats_y.rate - rd_stats_uv.rate,
+ rd_stats.sse);
+ }
+
+ if (tmp_rd < this_rd) {
+ this_rd = tmp_rd;
+ rate2 = rd_stats.rate;
+ skippable = rd_stats.skip;
+ distortion2 = rd_stats.dist;
+ total_sse = rd_stats.sse;
+ rate_y = rd_stats_y.rate;
+ rate_uv = rd_stats_uv.rate;
+ disable_skip = dummy_disable_skip;
+ backup_mbmi = *mbmi;
+ }
+ }
+ *mbmi = backup_mbmi;
+
+ // TODO(chengchen): Redo encoding use the selected compound_idx
+ // But ideally, this is unnecessary
+ {
+ RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
+ av1_init_rd_stats(&rd_stats);
+ av1_init_rd_stats(&rd_stats_y);
+ av1_init_rd_stats(&rd_stats_uv);
+ rd_stats.rate = cum_rate;
+
+ memcpy(frame_mv, backup_frame_mv, sizeof(frame_mv));
+ memcpy(single_newmv, backup_single_newmv, sizeof(single_newmv));
+ memcpy(single_newmv_rate, backup_single_newmv_rate,
+ sizeof(single_newmv_rate));
+ memcpy(modelled_rd, backup_modelled_rd, sizeof(modelled_rd));
+
+ mbmi->interp_filters = backup_interp_filters;
+
+ int dummy_disable_skip = 0;
+
+ args.single_newmv = single_newmv;
+ args.single_newmv_rate = single_newmv_rate;
+ args.modelled_rd = modelled_rd;
+
+ int64_t tmp_rd = handle_inter_mode(
+ cpi, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
+ &dummy_disable_skip, frame_mv, mi_row, mi_col, &args, best_rd);
+
+ if (tmp_rd < INT64_MAX) {
+ if (RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist) <
+ RDCOST(x->rdmult, 0, rd_stats.sse))
+ tmp_rd =
+ RDCOST(x->rdmult, rd_stats.rate + x->skip_cost[skip_ctx][0],
+ rd_stats.dist);
+ else
+ tmp_rd = RDCOST(x->rdmult,
+ rd_stats.rate + x->skip_cost[skip_ctx][1] -
+ rd_stats_y.rate - rd_stats_uv.rate,
+ rd_stats.sse);
+ }
+
+ this_rd = tmp_rd;
+ rate2 = rd_stats.rate;
+ skippable = rd_stats.skip;
+ distortion2 = rd_stats.dist;
+ total_sse = rd_stats.sse;
+ rate_y = rd_stats_y.rate;
+ rate_uv = rd_stats_uv.rate;
+ disable_skip = dummy_disable_skip;
+ }
+ }
+#else // CONFIG_JNT_COMP
{
RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
av1_init_rd_stats(&rd_stats);
@@ -10240,6 +10369,7 @@
rate_y = rd_stats_y.rate;
rate_uv = rd_stats_uv.rate;
}
+#endif // CONFIG_JNT_COMP
// TODO(jingning): This needs some refactoring to improve code quality
// and reduce redundant steps.
@@ -10293,12 +10423,22 @@
memcpy(x->blk_skip_drl[i], x->blk_skip[i],
sizeof(uint8_t) * ctx->num_4x4_blk);
- for (ref_idx = 0; ref_idx < ref_set; ++ref_idx) {
+#if CONFIG_JNT_COMP
+ for (int sidx = 0; sidx < ref_set * (1 + has_second_ref(mbmi)); ++sidx)
+#else
+ for (ref_idx = 0; ref_idx < ref_set; ++ref_idx)
+#endif // CONFIG_JNT_COMP
+ {
int64_t tmp_alt_rd = INT64_MAX;
int dummy_disable_skip = 0;
int ref;
int_mv cur_mv;
RD_STATS tmp_rd_stats, tmp_rd_stats_y, tmp_rd_stats_uv;
+#if CONFIG_JNT_COMP
+ ref_idx = sidx;
+ if (has_second_ref(mbmi)) ref_idx /= 2;
+ mbmi->compound_idx = sidx % 2;
+#endif // CONFIG_JNT_COMP
av1_invalid_rd_stats(&tmp_rd_stats);
@@ -10480,6 +10620,9 @@
for (i = 0; i < MAX_MB_PLANE; ++i)
memcpy(x->blk_skip[i], x->blk_skip_drl[i],
sizeof(uint8_t) * ctx->num_4x4_blk);
+#if CONFIG_JNT_COMP
+ *mbmi = backup_mbmi;
+#endif // CONFIG_JNT_COMP
}
mbmi_ext->ref_mvs[ref_frame][0] = backup_ref_mv[0];
if (comp_pred) mbmi_ext->ref_mvs[second_ref_frame][0] = backup_ref_mv[1];