Use weighted averaging of CDFs in enc row-mt
In case of row-based multi-threading of encoder, since we always
keep a top-right sync, we can average the top-right SB's CDFs and
the left SB's CDFs and use the same for current SB's encoding to
improve the performance.
Using weighted averaging i.e., giving more weight (75%) to left
SB’s CDF and less weight (25%) to top-right SB’s CDF, further
helps in reducing the BD rate gap between enc row-mt and tile-mt.
STATS_CHANGED for row-mt.
Change-Id: Ic66a4710b198d649f6077d7437d789c96e743d4a
diff --git a/av1/encoder/av1_multi_thread.c b/av1/encoder/av1_multi_thread.c
index a0c556e..ff4c4ac 100644
--- a/av1/encoder/av1_multi_thread.c
+++ b/av1/encoder/av1_multi_thread.c
@@ -35,6 +35,12 @@
&cpi->tile_data[tile_row * multi_thread_ctxt->allocated_tile_cols +
tile_col];
av1_row_mt_sync_mem_alloc(&this_tile->row_mt_sync, cm, max_sb_rows);
+ CHECK_MEM_ERROR(
+ cm, this_tile->row_ctx,
+ (FRAME_CONTEXT *)aom_memalign(
+ 16, AOMMAX(1, (av1_get_sb_cols_in_tile(cm, this_tile->tile_info) -
+ 1)) *
+ sizeof(*this_tile->row_ctx)));
}
}
}
@@ -53,6 +59,7 @@
&cpi->tile_data[tile_row * multi_thread_ctxt->allocated_tile_cols +
tile_col];
av1_row_mt_sync_mem_dealloc(&this_tile->row_mt_sync);
+ aom_free(this_tile->row_ctx);
}
}
multi_thread_ctxt->allocated_sb_rows = 0;
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 0471de5..5df6677 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -277,7 +277,7 @@
CONV_BUF_TYPE *tmp_conv_dst;
uint8_t *tmp_obmc_bufs[2];
- FRAME_CONTEXT *backup_tile_ctx;
+ FRAME_CONTEXT *row_ctx;
// This context will be used to update color_map_cdf pointer which would be
// used during pack bitstream. For single thread and tile-multithreading case
// this ponter will be same as xd->tile_ctx, but for the case of row-mt:
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 38da2a1..85caff65 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -5317,6 +5317,182 @@
}
}
+#define AVG_CDF_WEIGHT_LEFT 3
+#define AVG_CDF_WEIGHT_TOP_RIGHT 1
+
+static void avg_cdf_symbol(aom_cdf_prob *cdf_ptr_left, aom_cdf_prob *cdf_ptr_tr,
+ int num_cdfs, int cdf_stride, int nsymbs,
+ int wt_left, int wt_tr) {
+ for (int i = 0; i < num_cdfs; i++) {
+ for (int j = 0; j <= nsymbs; j++) {
+ cdf_ptr_left[i * cdf_stride + j] =
+ (aom_cdf_prob)(((int)cdf_ptr_left[i * cdf_stride + j] * wt_left +
+ (int)cdf_ptr_tr[i * cdf_stride + j] * wt_tr +
+ ((wt_left + wt_tr) / 2)) /
+ (wt_left + wt_tr));
+ assert(cdf_ptr_left[i * cdf_stride + j] >= 0 &&
+ cdf_ptr_left[i * cdf_stride + j] < CDF_PROB_TOP);
+ }
+ }
+}
+
+#define AVERAGE_CDF(cname_left, cname_tr, nsymbs) \
+ AVG_CDF_STRIDE(cname_left, cname_tr, nsymbs, CDF_SIZE(nsymbs))
+
+#define AVG_CDF_STRIDE(cname_left, cname_tr, nsymbs, cdf_stride) \
+ do { \
+ aom_cdf_prob *cdf_ptr_left = (aom_cdf_prob *)cname_left; \
+ aom_cdf_prob *cdf_ptr_tr = (aom_cdf_prob *)cname_tr; \
+ int array_size = (int)sizeof(cname_left) / sizeof(aom_cdf_prob); \
+ int num_cdfs = array_size / cdf_stride; \
+ avg_cdf_symbol(cdf_ptr_left, cdf_ptr_tr, num_cdfs, cdf_stride, nsymbs, \
+ wt_left, wt_tr); \
+ } while (0)
+
+static void avg_nmv(nmv_context *nmv_left, nmv_context *nmv_tr, int wt_left,
+ int wt_tr) {
+ AVERAGE_CDF(nmv_left->joints_cdf, nmv_tr->joints_cdf, 4);
+ for (int i = 0; i < 2; i++) {
+ AVERAGE_CDF(nmv_left->comps[i].classes_cdf, nmv_tr->comps[i].classes_cdf,
+ MV_CLASSES);
+ AVERAGE_CDF(nmv_left->comps[i].class0_fp_cdf,
+ nmv_tr->comps[i].class0_fp_cdf, MV_FP_SIZE);
+ AVERAGE_CDF(nmv_left->comps[i].fp_cdf, nmv_tr->comps[i].fp_cdf, MV_FP_SIZE);
+ AVERAGE_CDF(nmv_left->comps[i].sign_cdf, nmv_tr->comps[i].sign_cdf, 2);
+ AVERAGE_CDF(nmv_left->comps[i].class0_hp_cdf,
+ nmv_tr->comps[i].class0_hp_cdf, 2);
+ AVERAGE_CDF(nmv_left->comps[i].hp_cdf, nmv_tr->comps[i].hp_cdf, 2);
+ AVERAGE_CDF(nmv_left->comps[i].class0_cdf, nmv_tr->comps[i].class0_cdf,
+ CLASS0_SIZE);
+ AVERAGE_CDF(nmv_left->comps[i].bits_cdf, nmv_tr->comps[i].bits_cdf, 2);
+ }
+}
+
+// In case of row-based multi-threading of encoder, since we always
+// keep a top - right sync, we can average the top - right SB's CDFs and
+// the left SB's CDFs and use the same for current SB's encoding to
+// improve the performance. This function facilitates the averaging
+// of CDF and used only when row-mt is enabled in encoder.
+static void avg_cdf_symbols(FRAME_CONTEXT *ctx_left, FRAME_CONTEXT *ctx_tr,
+ int wt_left, int wt_tr) {
+ AVERAGE_CDF(ctx_left->txb_skip_cdf, ctx_tr->txb_skip_cdf, 2);
+ AVERAGE_CDF(ctx_left->eob_extra_cdf, ctx_tr->eob_extra_cdf, 2);
+ AVERAGE_CDF(ctx_left->dc_sign_cdf, ctx_tr->dc_sign_cdf, 2);
+ AVERAGE_CDF(ctx_left->eob_flag_cdf16, ctx_tr->eob_flag_cdf16, 5);
+ AVERAGE_CDF(ctx_left->eob_flag_cdf32, ctx_tr->eob_flag_cdf32, 6);
+ AVERAGE_CDF(ctx_left->eob_flag_cdf64, ctx_tr->eob_flag_cdf64, 7);
+ AVERAGE_CDF(ctx_left->eob_flag_cdf128, ctx_tr->eob_flag_cdf128, 8);
+ AVERAGE_CDF(ctx_left->eob_flag_cdf256, ctx_tr->eob_flag_cdf256, 9);
+ AVERAGE_CDF(ctx_left->eob_flag_cdf512, ctx_tr->eob_flag_cdf512, 10);
+ AVERAGE_CDF(ctx_left->eob_flag_cdf1024, ctx_tr->eob_flag_cdf1024, 11);
+ AVERAGE_CDF(ctx_left->coeff_base_eob_cdf, ctx_tr->coeff_base_eob_cdf, 3);
+ AVERAGE_CDF(ctx_left->coeff_base_cdf, ctx_tr->coeff_base_cdf, 4);
+ AVERAGE_CDF(ctx_left->coeff_br_cdf, ctx_tr->coeff_br_cdf, BR_CDF_SIZE);
+ AVERAGE_CDF(ctx_left->newmv_cdf, ctx_tr->newmv_cdf, 2);
+ AVERAGE_CDF(ctx_left->zeromv_cdf, ctx_tr->zeromv_cdf, 2);
+ AVERAGE_CDF(ctx_left->refmv_cdf, ctx_tr->refmv_cdf, 2);
+ AVERAGE_CDF(ctx_left->drl_cdf, ctx_tr->drl_cdf, 2);
+ AVERAGE_CDF(ctx_left->inter_compound_mode_cdf,
+ ctx_tr->inter_compound_mode_cdf, INTER_COMPOUND_MODES);
+ AVERAGE_CDF(ctx_left->compound_type_cdf, ctx_tr->compound_type_cdf,
+ COMPOUND_TYPES - 1);
+ AVERAGE_CDF(ctx_left->wedge_idx_cdf, ctx_tr->wedge_idx_cdf, 16);
+ AVERAGE_CDF(ctx_left->interintra_cdf, ctx_tr->interintra_cdf, 2);
+ AVERAGE_CDF(ctx_left->wedge_interintra_cdf, ctx_tr->wedge_interintra_cdf, 2);
+ AVERAGE_CDF(ctx_left->interintra_mode_cdf, ctx_tr->interintra_mode_cdf,
+ INTERINTRA_MODES);
+ AVERAGE_CDF(ctx_left->motion_mode_cdf, ctx_tr->motion_mode_cdf, MOTION_MODES);
+ AVERAGE_CDF(ctx_left->obmc_cdf, ctx_tr->obmc_cdf, 2);
+ AVERAGE_CDF(ctx_left->palette_y_size_cdf, ctx_tr->palette_y_size_cdf,
+ PALETTE_SIZES);
+ AVERAGE_CDF(ctx_left->palette_uv_size_cdf, ctx_tr->palette_uv_size_cdf,
+ PALETTE_SIZES);
+ for (int j = 0; j < PALETTE_SIZES; j++) {
+ int nsymbs = j + PALETTE_MIN_SIZE;
+ AVG_CDF_STRIDE(ctx_left->palette_y_color_index_cdf[j],
+ ctx_tr->palette_y_color_index_cdf[j], nsymbs,
+ CDF_SIZE(PALETTE_COLORS));
+ AVG_CDF_STRIDE(ctx_left->palette_uv_color_index_cdf[j],
+ ctx_tr->palette_uv_color_index_cdf[j], nsymbs,
+ CDF_SIZE(PALETTE_COLORS));
+ }
+ AVERAGE_CDF(ctx_left->palette_y_mode_cdf, ctx_tr->palette_y_mode_cdf, 2);
+ AVERAGE_CDF(ctx_left->palette_uv_mode_cdf, ctx_tr->palette_uv_mode_cdf, 2);
+ AVERAGE_CDF(ctx_left->comp_inter_cdf, ctx_tr->comp_inter_cdf, 2);
+ AVERAGE_CDF(ctx_left->single_ref_cdf, ctx_tr->single_ref_cdf, 2);
+ AVERAGE_CDF(ctx_left->comp_ref_type_cdf, ctx_tr->comp_ref_type_cdf, 2);
+ AVERAGE_CDF(ctx_left->uni_comp_ref_cdf, ctx_tr->uni_comp_ref_cdf, 2);
+ AVERAGE_CDF(ctx_left->comp_ref_cdf, ctx_tr->comp_ref_cdf, 2);
+ AVERAGE_CDF(ctx_left->comp_bwdref_cdf, ctx_tr->comp_bwdref_cdf, 2);
+ AVERAGE_CDF(ctx_left->txfm_partition_cdf, ctx_tr->txfm_partition_cdf, 2);
+ AVERAGE_CDF(ctx_left->compound_index_cdf, ctx_tr->compound_index_cdf, 2);
+ AVERAGE_CDF(ctx_left->comp_group_idx_cdf, ctx_tr->comp_group_idx_cdf, 2);
+ AVERAGE_CDF(ctx_left->skip_mode_cdfs, ctx_tr->skip_mode_cdfs, 2);
+ AVERAGE_CDF(ctx_left->skip_cdfs, ctx_tr->skip_cdfs, 2);
+ AVERAGE_CDF(ctx_left->intra_inter_cdf, ctx_tr->intra_inter_cdf, 2);
+ avg_nmv(&ctx_left->nmvc, &ctx_tr->nmvc, wt_left, wt_tr);
+ avg_nmv(&ctx_left->ndvc, &ctx_tr->ndvc, wt_left, wt_tr);
+ AVERAGE_CDF(ctx_left->intrabc_cdf, ctx_tr->intrabc_cdf, 2);
+ AVERAGE_CDF(ctx_left->seg.tree_cdf, ctx_tr->seg.tree_cdf, MAX_SEGMENTS);
+ AVERAGE_CDF(ctx_left->seg.pred_cdf, ctx_tr->seg.pred_cdf, 2);
+ AVERAGE_CDF(ctx_left->seg.spatial_pred_seg_cdf,
+ ctx_tr->seg.spatial_pred_seg_cdf, MAX_SEGMENTS);
+ AVERAGE_CDF(ctx_left->filter_intra_cdfs, ctx_tr->filter_intra_cdfs, 2);
+ AVERAGE_CDF(ctx_left->filter_intra_mode_cdf, ctx_tr->filter_intra_mode_cdf,
+ FILTER_INTRA_MODES);
+ AVERAGE_CDF(ctx_left->switchable_restore_cdf, ctx_tr->switchable_restore_cdf,
+ RESTORE_SWITCHABLE_TYPES);
+ AVERAGE_CDF(ctx_left->wiener_restore_cdf, ctx_tr->wiener_restore_cdf, 2);
+ AVERAGE_CDF(ctx_left->sgrproj_restore_cdf, ctx_tr->sgrproj_restore_cdf, 2);
+ AVERAGE_CDF(ctx_left->y_mode_cdf, ctx_tr->y_mode_cdf, INTRA_MODES);
+ AVG_CDF_STRIDE(ctx_left->uv_mode_cdf[0], ctx_tr->uv_mode_cdf[0],
+ UV_INTRA_MODES - 1, CDF_SIZE(UV_INTRA_MODES));
+ AVERAGE_CDF(ctx_left->uv_mode_cdf[1], ctx_tr->uv_mode_cdf[1], UV_INTRA_MODES);
+ for (int i = 0; i < PARTITION_CONTEXTS; i++) {
+ if (i < 4) {
+ AVG_CDF_STRIDE(ctx_left->partition_cdf[i], ctx_tr->partition_cdf[i], 4,
+ CDF_SIZE(10));
+ } else if (i < 16) {
+ AVERAGE_CDF(ctx_left->partition_cdf[i], ctx_tr->partition_cdf[i], 10);
+ } else {
+ AVG_CDF_STRIDE(ctx_left->partition_cdf[i], ctx_tr->partition_cdf[i], 8,
+ CDF_SIZE(10));
+ }
+ }
+ AVERAGE_CDF(ctx_left->switchable_interp_cdf, ctx_tr->switchable_interp_cdf,
+ SWITCHABLE_FILTERS);
+ AVERAGE_CDF(ctx_left->kf_y_cdf, ctx_tr->kf_y_cdf, INTRA_MODES);
+ AVERAGE_CDF(ctx_left->angle_delta_cdf, ctx_tr->angle_delta_cdf,
+ 2 * MAX_ANGLE_DELTA + 1);
+ AVG_CDF_STRIDE(ctx_left->tx_size_cdf[0], ctx_tr->tx_size_cdf[0], MAX_TX_DEPTH,
+ CDF_SIZE(MAX_TX_DEPTH + 1));
+ AVERAGE_CDF(ctx_left->tx_size_cdf[1], ctx_tr->tx_size_cdf[1],
+ MAX_TX_DEPTH + 1);
+ AVERAGE_CDF(ctx_left->tx_size_cdf[2], ctx_tr->tx_size_cdf[2],
+ MAX_TX_DEPTH + 1);
+ AVERAGE_CDF(ctx_left->tx_size_cdf[3], ctx_tr->tx_size_cdf[3],
+ MAX_TX_DEPTH + 1);
+ AVERAGE_CDF(ctx_left->delta_q_cdf, ctx_tr->delta_q_cdf, DELTA_Q_PROBS + 1);
+ AVERAGE_CDF(ctx_left->delta_lf_cdf, ctx_tr->delta_lf_cdf, DELTA_LF_PROBS + 1);
+ for (int i = 0; i < FRAME_LF_COUNT; i++) {
+ AVERAGE_CDF(ctx_left->delta_lf_multi_cdf[i], ctx_tr->delta_lf_multi_cdf[i],
+ DELTA_LF_PROBS + 1);
+ }
+ AVG_CDF_STRIDE(ctx_left->intra_ext_tx_cdf[1], ctx_tr->intra_ext_tx_cdf[1], 7,
+ CDF_SIZE(TX_TYPES));
+ AVG_CDF_STRIDE(ctx_left->intra_ext_tx_cdf[2], ctx_tr->intra_ext_tx_cdf[2], 5,
+ CDF_SIZE(TX_TYPES));
+ AVG_CDF_STRIDE(ctx_left->inter_ext_tx_cdf[1], ctx_tr->inter_ext_tx_cdf[1], 16,
+ CDF_SIZE(TX_TYPES));
+ AVG_CDF_STRIDE(ctx_left->inter_ext_tx_cdf[2], ctx_tr->inter_ext_tx_cdf[2], 12,
+ CDF_SIZE(TX_TYPES));
+ AVG_CDF_STRIDE(ctx_left->inter_ext_tx_cdf[3], ctx_tr->inter_ext_tx_cdf[3], 2,
+ CDF_SIZE(TX_TYPES));
+ AVERAGE_CDF(ctx_left->cfl_sign_cdf, ctx_tr->cfl_sign_cdf, CFL_JOINT_SIGNS);
+ AVERAGE_CDF(ctx_left->cfl_alpha_cdf, ctx_tr->cfl_alpha_cdf,
+ CFL_ALPHABET_SIZE);
+}
+
static void encode_rd_sb_row(AV1_COMP *cpi, ThreadData *td,
TileDataEnc *tile_data, int mi_row,
TOKENEXTRA **tp) {
@@ -5350,10 +5526,22 @@
mi_col < tile_info->mi_col_end; mi_col += mib_size, sb_col_in_tile++) {
(*(cpi->row_mt_sync_read_ptr))(&tile_data->row_mt_sync, sb_row,
sb_col_in_tile);
- if ((cpi->row_mt == 1) && (tile_info->mi_col_start == mi_col) &&
- (tile_info->mi_row_start != mi_row)) {
- // restore frame context of 1st column sb
- memcpy(xd->tile_ctx, x->backup_tile_ctx, sizeof(*xd->tile_ctx));
+ if ((cpi->row_mt == 1) && (tile_info->mi_row_start != mi_row)) {
+ if ((tile_info->mi_col_start == mi_col)) {
+ // restore frame context of 1st column sb
+ memcpy(xd->tile_ctx, x->row_ctx, sizeof(*xd->tile_ctx));
+ } else {
+ if (tile_data->allow_update_cdf) {
+ int wt_left = AVG_CDF_WEIGHT_LEFT;
+ int wt_tr = AVG_CDF_WEIGHT_TOP_RIGHT;
+ if (tile_info->mi_col_end > (mi_col + mib_size))
+ avg_cdf_symbols(xd->tile_ctx, x->row_ctx + sb_col_in_tile, wt_left,
+ wt_tr);
+ else
+ avg_cdf_symbols(xd->tile_ctx, x->row_ctx + sb_col_in_tile - 1,
+ wt_left, wt_tr);
+ }
+ }
}
av1_fill_coeff_costs(&td->mb, xd->tile_ctx, num_planes);
av1_fill_mode_rates(cm, x, xd->tile_ctx);
@@ -5459,11 +5647,12 @@
av1_inter_mode_data_fit(tile_data, x->rdmult);
}
#endif
- if (cpi->row_mt == 1) {
- int update_context = 0;
- update_context = sb_cols_in_tile == 1 || sb_col_in_tile == 1;
- if (update_context)
- memcpy(x->backup_tile_ctx, xd->tile_ctx, sizeof(*xd->tile_ctx));
+ if ((cpi->row_mt == 1) && (tile_info->mi_row_end > (mi_row + mib_size))) {
+ if (sb_cols_in_tile == 1)
+ memcpy(x->row_ctx, xd->tile_ctx, sizeof(*xd->tile_ctx));
+ else if (sb_col_in_tile >= 1)
+ memcpy(x->row_ctx + sb_col_in_tile - 1, xd->tile_ctx,
+ sizeof(*xd->tile_ctx));
}
(*(cpi->row_mt_sync_write_ptr))(&tile_data->row_mt_sync, sb_row,
sb_col_in_tile, sb_cols_in_tile);
@@ -5654,7 +5843,6 @@
cpi->td.intrabc_used = 0;
cpi->td.mb.e_mbd.tile_ctx = &this_tile->tctx;
cpi->td.mb.tile_pb_ctx = &this_tile->tctx;
- cpi->td.mb.backup_tile_ctx = &this_tile->backup_tctx;
av1_encode_tile(cpi, &cpi->td, tile_row, tile_col);
cpi->intrabc_used |= cpi->td.intrabc_used;
}
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 2a522e2..609f6a0 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -511,7 +511,7 @@
int ex_search_count;
CFL_CTX cfl;
DECLARE_ALIGNED(16, FRAME_CONTEXT, tctx);
- DECLARE_ALIGNED(16, FRAME_CONTEXT, backup_tctx);
+ FRAME_CONTEXT *row_ctx;
uint8_t allow_update_cdf;
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
InterModeRdModel inter_mode_rd_models[BLOCK_SIZES_ALL];
diff --git a/av1/encoder/ethread.c b/av1/encoder/ethread.c
index 3f720ee..a82f2af 100644
--- a/av1/encoder/ethread.c
+++ b/av1/encoder/ethread.c
@@ -310,7 +310,7 @@
td->mb.e_mbd.tile_ctx = td->tctx;
td->mb.tile_pb_ctx = &this_tile->tctx;
- td->mb.backup_tile_ctx = &this_tile->backup_tctx;
+ td->mb.row_ctx = this_tile->row_ctx;
if (current_mi_row == this_tile->tile_info.mi_row_start)
memcpy(td->mb.e_mbd.tile_ctx, &this_tile->tctx, sizeof(FRAME_CONTEXT));
av1_init_above_context(cm, &td->mb.e_mbd, tile_row);
@@ -355,7 +355,6 @@
&cpi->tile_data[tile_row * cm->tile_cols + tile_col];
thread_data->td->mb.e_mbd.tile_ctx = &this_tile->tctx;
thread_data->td->mb.tile_pb_ctx = &this_tile->tctx;
- thread_data->td->mb.backup_tile_ctx = &this_tile->backup_tctx;
av1_encode_tile(cpi, thread_data->td, tile_row, tile_col);
}