Facilitate multiwinner mode processing framework for inter frames
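This patch adds the framework needed for multi-winner mode processing in
inter frames: per-candidate RD stats and the mode index are stored with
each winner mode, and refine_winner_mode_tx() now iterates over the stored
winner candidates. The inter-frame winner count is currently set to 1.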
Change-Id: Icfa93bd5a607bcfe1c70a2c936094f18a3da7c21
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index da80a30..34df73b 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -35,11 +35,16 @@
#define MC_FLOW_BSIZE_1D 16
#define MC_FLOW_NUM_PELS (MC_FLOW_BSIZE_1D * MC_FLOW_BSIZE_1D)
#define MAX_MC_FLOW_BLK_IN_SB (MAX_SB_SIZE / MC_FLOW_BSIZE_1D)
-#define MAX_WINNER_MODE_COUNT 3
+#define MAX_WINNER_MODE_COUNT_INTRA 3
+#define MAX_WINNER_MODE_COUNT_INTER 1
typedef struct {
MB_MODE_INFO mbmi;
+ RD_STATS rd_cost;  // RD stats of the mode; filled for inter frames only
int64_t rd;
+ int rate_y;  // rd_cost_y rate plus skip cost; set once tx search is done
+ int rate_uv;  // rd_cost_uv rate; set once tx search is done
uint8_t color_index_map[64 * 64];
+ THR_MODES mode_index;  // Mode index supplied when the stat was stored
} WinnerModeStats;
typedef struct {
@@ -235,7 +240,8 @@
MB_MODE_INFO_EXT *mbmi_ext;
MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame;
// Array of mode stats for winner mode processing
- WinnerModeStats winner_mode_stats[MAX_WINNER_MODE_COUNT];
+ WinnerModeStats winner_mode_stats[AOMMAX(MAX_WINNER_MODE_COUNT_INTRA,
+ MAX_WINNER_MODE_COUNT_INTER)];
int winner_mode_count;
int skip_block;
int qindex;
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 5005047..42e72ff 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4558,36 +4558,63 @@
}
// Store best mode stats for winner mode processing
-static void store_winner_mode_stats(MACROBLOCK *x, MB_MODE_INFO *mbmi,
+static void store_winner_mode_stats(const AV1_COMMON *const cm, MACROBLOCK *x,
+ MB_MODE_INFO *mbmi, RD_STATS *rd_cost,
+ RD_STATS *rd_cost_y, RD_STATS *rd_cost_uv,
+ THR_MODES mode_index, uint8_t *color_map,
+ BLOCK_SIZE bsize, int64_t this_rd,
int enable_multiwinner_mode_process,
- uint8_t *color_map, BLOCK_SIZE bsize,
- int64_t this_rd) {
+ int txfm_search_done) {
WinnerModeStats *winner_mode_stats = x->winner_mode_stats;
int mode_idx = 0;
+ int is_palette_mode = mbmi->palette_mode_info.palette_size[PLANE_TYPE_Y] > 0;
// Mode stat is not required when multiwinner mode processing is disabled
if (!enable_multiwinner_mode_process) return;
+ // TODO(any): Winner mode processing is currently not applicable for palette
+ // mode in Inter frames. Clean-up the following code, once support is added
+ if (!frame_is_intra_only(cm) && is_palette_mode) return;
+ const int max_winner_mode_count = frame_is_intra_only(cm)
+ ? MAX_WINNER_MODE_COUNT_INTRA
+ : MAX_WINNER_MODE_COUNT_INTER;
assert(x->winner_mode_count >= 0 &&
- x->winner_mode_count <= MAX_WINNER_MODE_COUNT);
+ x->winner_mode_count <= max_winner_mode_count);
if (x->winner_mode_count) {
// Find the mode which has higher rd cost than this_rd
for (mode_idx = 0; mode_idx < x->winner_mode_count; mode_idx++)
if (winner_mode_stats[mode_idx].rd > this_rd) break;
- if (mode_idx == MAX_WINNER_MODE_COUNT) {
+ if (mode_idx == max_winner_mode_count) {
// No mode has higher rd cost than this_rd
return;
- } else if (mode_idx < MAX_WINNER_MODE_COUNT - 1) {
+ } else if (mode_idx < max_winner_mode_count - 1) {
// Create a slot for current mode and move others to the next slot
memmove(
&winner_mode_stats[mode_idx + 1], &winner_mode_stats[mode_idx],
- (MAX_WINNER_MODE_COUNT - mode_idx - 1) * sizeof(*winner_mode_stats));
+ (max_winner_mode_count - mode_idx - 1) * sizeof(*winner_mode_stats));
}
}
// Add a mode stat for winner mode processing
winner_mode_stats[mode_idx].mbmi = *mbmi;
winner_mode_stats[mode_idx].rd = this_rd;
+ winner_mode_stats[mode_idx].mode_index = mode_index;
+
+ // Update rd stats required for inter frame
+ if (!frame_is_intra_only(cm) && rd_cost && rd_cost_y && rd_cost_uv) {
+ const MACROBLOCKD *xd = &x->e_mbd;
+ const int skip_ctx = av1_get_skip_context(xd);
+ const int is_intra_mode = av1_mode_defs[mode_index].mode < INTRA_MODE_END;
+ const int skip = mbmi->skip && !is_intra_mode;
+
+ winner_mode_stats[mode_idx].rd_cost = *rd_cost;
+ if (txfm_search_done) {
+ winner_mode_stats[mode_idx].rate_y =
+ rd_cost_y->rate + x->skip_cost[skip_ctx][rd_cost->skip || skip];
+ winner_mode_stats[mode_idx].rate_uv = rd_cost_uv->rate;
+ }
+ }
+
if (color_map) {
// Store color_index_map for palette mode
const MACROBLOCKD *const xd = &x->e_mbd;
@@ -4599,7 +4626,7 @@
}
x->winner_mode_count =
- AOMMIN(x->winner_mode_count + 1, MAX_WINNER_MODE_COUNT);
+ AOMMIN(x->winner_mode_count + 1, max_winner_mode_count);
}
// Given the base colors as specified in centroids[], calculate the RD cost
@@ -4652,8 +4679,10 @@
tokenonly_rd_stats.rate -= tx_size_cost(x, bsize, mbmi->tx_size);
}
// Collect mode stats for multiwinner mode processing
- store_winner_mode_stats(x, mbmi, cpi->sf.enable_multiwinner_mode_process,
- color_map, bsize, this_rd);
+ const int txfm_search_done = 1;
+ store_winner_mode_stats(
+ &cpi->common, x, mbmi, NULL, NULL, NULL, THR_DC, color_map, bsize,
+ this_rd, cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
if (this_rd < *best_rd) {
*best_rd = this_rd;
// Setting beat_best_rd flag because current mode rd is better than best_rd.
@@ -4917,8 +4946,10 @@
this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
// Collect mode stats for multiwinner mode processing
- store_winner_mode_stats(x, mbmi, cpi->sf.enable_multiwinner_mode_process,
- NULL, bsize, this_rd);
+ const int txfm_search_done = 1;
+ store_winner_mode_stats(
+ &cpi->common, x, mbmi, NULL, NULL, NULL, 0, NULL, bsize, this_rd,
+ cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
if (this_rd < *best_rd) {
*best_rd = this_rd;
best_tx_size = mbmi->tx_size;
@@ -5298,10 +5329,13 @@
set_mode_eval_params(cpi, x, MODE_EVAL);
MB_MODE_INFO best_mbmi = *mbmi;
+ av1_zero(x->winner_mode_stats);
x->winner_mode_count = 0;
// Initialize best mode stats for winner mode processing
- store_winner_mode_stats(x, mbmi, cpi->sf.enable_multiwinner_mode_process,
- NULL, bsize, best_rd);
+ const int txfm_search_done = 1;
+ store_winner_mode_stats(
+ &cpi->common, x, mbmi, NULL, NULL, NULL, 0, NULL, bsize, best_rd,
+ cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
/* Y Search for intra prediction mode */
for (int mode_idx = INTRA_MODE_START; mode_idx < INTRA_MODE_END; ++mode_idx) {
RD_STATS this_rd_stats;
@@ -5350,8 +5384,9 @@
intra_mode_info_cost_y(cpi, x, mbmi, bsize, bmode_costs[mbmi->mode]);
this_rd = RDCOST(x->rdmult, this_rate, this_distortion);
// Collect mode stats for multiwinner mode processing
- store_winner_mode_stats(x, mbmi, cpi->sf.enable_multiwinner_mode_process,
- NULL, bsize, this_rd);
+ store_winner_mode_stats(
+ &cpi->common, x, mbmi, NULL, NULL, NULL, 0, NULL, bsize, this_rd,
+ cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
if (this_rd < best_rd) {
best_mbmi = *mbmi;
best_rd = this_rd;
@@ -11118,6 +11153,13 @@
}
}
if (skip) {
+ const THR_MODES mode_enum = get_prediction_mode_idx(
+ best_mbmi.mode, best_mbmi.ref_frame[0], best_mbmi.ref_frame[1]);
+ // Collect mode stats for multiwinner mode processing
+ store_winner_mode_stats(
+ &cpi->common, x, &best_mbmi, &best_rd_stats, &best_rd_stats_y,
+ &best_rd_stats_uv, mode_enum, NULL, bsize, best_rd,
+ cpi->sf.enable_multiwinner_mode_process, do_tx_search);
args->modelled_rd[this_mode][ref_mv_idx][refs[0]] =
args->modelled_rd[this_mode][i][refs[0]];
args->simple_rd[this_mode][ref_mv_idx][refs[0]] =
@@ -11278,6 +11320,13 @@
if (ret_val != INT64_MAX) {
int64_t tmp_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
mode_info[ref_mv_idx].rd = tmp_rd;
+ const THR_MODES mode_enum = get_prediction_mode_idx(
+ mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+ // Collect mode stats for multiwinner mode processing
+ store_winner_mode_stats(&cpi->common, x, mbmi, rd_stats, rd_stats_y,
+ rd_stats_uv, mode_enum, NULL, bsize, tmp_rd,
+ cpi->sf.enable_multiwinner_mode_process,
+ do_tx_search);
if (tmp_rd < best_rd) {
best_rd_stats = *rd_stats;
best_rd_stats_y = *rd_stats_y;
@@ -11776,6 +11825,33 @@
}
}
+// Get stats of winner mode_idx, or best mode stats if multiwinner is off
+static AOM_INLINE MB_MODE_INFO *get_winner_mode_stats(
+ MACROBLOCK *x, MB_MODE_INFO *best_mbmode, RD_STATS *best_rd_cost,
+ int best_rate_y, int best_rate_uv, THR_MODES *best_mode_index,
+ RD_STATS **winner_rd_cost, int *winner_rate_y, int *winner_rate_uv,
+ THR_MODES *winner_mode_index, int enable_multiwinner_mode_process,
+ int mode_idx) {
+ MB_MODE_INFO *winner_mbmi;
+ if (enable_multiwinner_mode_process) {
+ assert(mode_idx >= 0 && mode_idx < x->winner_mode_count);
+ WinnerModeStats *winner_mode_stat = &x->winner_mode_stats[mode_idx];
+ winner_mbmi = &winner_mode_stat->mbmi;
+
+ *winner_rd_cost = &winner_mode_stat->rd_cost;
+ *winner_rate_y = winner_mode_stat->rate_y;
+ *winner_rate_uv = winner_mode_stat->rate_uv;
+ *winner_mode_index = winner_mode_stat->mode_index;
+ } else {  // Multiwinner off: report the caller's best mode stats
+ winner_mbmi = best_mbmode;
+ *winner_rd_cost = best_rd_cost;
+ *winner_rate_y = best_rate_y;
+ *winner_rate_uv = best_rate_uv;
+ *winner_mode_index = *best_mode_index;
+ }
+ return winner_mbmi;
+}
+
// speed feature: fast intra/inter transform type search
// Used for speed >= 2
// When this speed feature is on, in rd mode search, only DCT is used.
@@ -11783,24 +11859,47 @@
// transform types and get accurate rdcost.
static AOM_INLINE void refine_winner_mode_tx(
const AV1_COMP *cpi, MACROBLOCK *x, RD_STATS *rd_cost, BLOCK_SIZE bsize,
- PICK_MODE_CONTEXT *ctx, THR_MODES best_mode_index,
+ PICK_MODE_CONTEXT *ctx, THR_MODES *best_mode_index,
MB_MODE_INFO *best_mbmode, struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE],
- int best_rate_y, int best_rate_uv, int *best_skip2) {
+ int best_rate_y, int best_rate_uv, int *best_skip2, int winner_mode_count) {
const AV1_COMMON *const cm = &cpi->common;
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = xd->mi[0];
+ int64_t best_rd;
const int num_planes = av1_num_planes(cm);
- if (is_winner_mode_processing_enabled(cpi, mbmi, best_mbmode->mode)) {
- // Set params for winner mode evaluation
- set_mode_eval_params(cpi, x, WINNER_MODE_EVAL);
+ if (!is_winner_mode_processing_enabled(cpi, best_mbmode, best_mbmode->mode))
+ return;
- if (xd->lossless[mbmi->segment_id] == 0 && best_mode_index != THR_INVALID) {
+ // Set params for winner mode evaluation
+ set_mode_eval_params(cpi, x, WINNER_MODE_EVAL);
+
+ // No best mode identified so far
+ if (*best_mode_index == THR_INVALID) return;
+
+ best_rd = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
+ for (int mode_idx = 0; mode_idx < winner_mode_count; mode_idx++) {
+ RD_STATS *winner_rd_stats = NULL;
+ int winner_rate_y = 0, winner_rate_uv = 0;
+ THR_MODES winner_mode_index = 0;
+
+ // TODO(any): Combine best mode and multi-winner mode processing paths
+ // Get winner mode stats for current mode index
+ MB_MODE_INFO *winner_mbmi = get_winner_mode_stats(
+ x, best_mbmode, rd_cost, best_rate_y, best_rate_uv, best_mode_index,
+ &winner_rd_stats, &winner_rate_y, &winner_rate_uv, &winner_mode_index,
+ cpi->sf.enable_multiwinner_mode_process, mode_idx);
+
+ if (xd->lossless[winner_mbmi->segment_id] == 0 &&
+ winner_mode_index != THR_INVALID &&
+ is_winner_mode_processing_enabled(cpi, winner_mbmi,
+ winner_mbmi->mode)) {
+ RD_STATS rd_stats = *winner_rd_stats;
int skip_blk = 0;
RD_STATS rd_stats_y, rd_stats_uv;
const int skip_ctx = av1_get_skip_context(xd);
- *mbmi = *best_mbmode;
+ *mbmi = *winner_mbmi;
set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
@@ -11855,18 +11954,20 @@
skip_blk = 0;
rd_stats_y.rate += x->skip_cost[skip_ctx][0];
}
-
- if (RDCOST(x->rdmult, best_rate_y + best_rate_uv, rd_cost->dist) >
- RDCOST(x->rdmult, rd_stats_y.rate + rd_stats_uv.rate,
- (rd_stats_y.dist + rd_stats_uv.dist))) {
- best_mbmode->tx_size = mbmi->tx_size;
- av1_copy(best_mbmode->inter_tx_size, mbmi->inter_tx_size);
+ int this_rate = rd_stats.rate + rd_stats_y.rate + rd_stats_uv.rate -
+ winner_rate_y - winner_rate_uv;
+ int64_t this_rd =
+ RDCOST(x->rdmult, this_rate, (rd_stats_y.dist + rd_stats_uv.dist));
+ if (best_rd > this_rd) {
+ *best_mbmode = *mbmi;
+ *best_mode_index = winner_mode_index;
av1_copy_array(ctx->blk_skip, x->blk_skip, ctx->num_4x4_blk);
av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
- rd_cost->rate +=
- (rd_stats_y.rate + rd_stats_uv.rate - best_rate_y - best_rate_uv);
+ rd_cost->rate = this_rate;
rd_cost->dist = rd_stats_y.dist + rd_stats_uv.dist;
- rd_cost->rdcost = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
+ rd_cost->sse = rd_stats_y.sse + rd_stats_uv.sse;
+ rd_cost->rdcost = this_rd;
+ best_rd = this_rd;
*best_skip2 = skip_blk;
}
}
@@ -13124,9 +13225,14 @@
if (ret_value != INT64_MAX) {
rd_stats.rdcost = RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist);
+ const THR_MODES mode_enum = get_prediction_mode_idx(
+ mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+ // Collect mode stats for multiwinner mode processing
+ store_winner_mode_stats(
+ &cpi->common, x, mbmi, &rd_stats, &rd_stats_y, &rd_stats_uv,
+ mode_enum, NULL, bsize, rd_stats.rdcost,
+ cpi->sf.enable_multiwinner_mode_process, do_tx_search);
if (rd_stats.rdcost < search_state->best_rd) {
- const THR_MODES mode_enum = get_prediction_mode_idx(
- mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
update_search_state(search_state, rd_cost, ctx, &rd_stats, &rd_stats_y,
&rd_stats_uv, mode_enum, x, do_tx_search);
}
@@ -13275,6 +13381,13 @@
find_last_single_ref_mode_idx(av1_default_mode_order);
int prune_cpd_using_sr_stats_ready = 0;
+ // Initialize best mode stats for winner mode processing
+ av1_zero(x->winner_mode_stats);
+ x->winner_mode_count = 0;
+ store_winner_mode_stats(&cpi->common, x, mbmi, NULL, NULL, NULL, THR_INVALID,
+ NULL, bsize, best_rd_so_far,
+ cpi->sf.enable_multiwinner_mode_process, 0);
+
// Here midx is just an interator index that should not be used by itself
// except to keep track of the number of modes searched. It should be used
// with av1_default_mode_order to get the enum that defines the mode, which
@@ -13499,6 +13612,11 @@
inter_modes_info_sort(inter_modes_info, inter_modes_info->rd_idx_pair_arr);
search_state.best_rd = best_rd_so_far;
search_state.best_mode_index = THR_INVALID;
+ // Initialize best mode stats for winner mode processing
+ x->winner_mode_count = 0;
+ store_winner_mode_stats(
+ &cpi->common, x, mbmi, NULL, NULL, NULL, THR_INVALID, NULL, bsize,
+ best_rd_so_far, cpi->sf.enable_multiwinner_mode_process, do_tx_search);
inter_modes_info->num =
inter_modes_info->num < cpi->sf.num_inter_modes_for_tx_search
? inter_modes_info->num
@@ -13544,13 +13662,19 @@
x->skip_cost[skip_ctx][mbmi->skip]);
}
rd_stats.rdcost = RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist);
+ // TODO(chiyotsai@google.com): get_prediction_mode_idx gives incorrect
+ // output once we change the mode order. Fix this!
+ const THR_MODES mode_enum = get_prediction_mode_idx(
+ mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+
+ // Collect mode stats for multiwinner mode processing
+ const int txfm_search_done = 1;
+ store_winner_mode_stats(
+ &cpi->common, x, mbmi, &rd_stats, &rd_stats_y, &rd_stats_uv,
+ mode_enum, NULL, bsize, rd_stats.rdcost,
+ cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
if (rd_stats.rdcost < search_state.best_rd) {
- // TODO(chiyotsai@google.com): get_prediction_mode_idx gives incorrect
- // output once we change the mode order. Fix this!
- const THR_MODES mode_enum = get_prediction_mode_idx(
- mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
- const int txfm_search_done = 1;
update_search_state(&search_state, rd_cost, ctx, &rd_stats, &rd_stats_y,
&rd_stats_uv, mode_enum, x, txfm_search_done);
}
@@ -13628,8 +13752,15 @@
intra_rd_stats.rdcost = handle_intra_mode(
&search_state, cpi, x, bsize, intra_ref_frame_cost, ctx, 0,
&intra_rd_stats, &intra_rd_stats_y, &intra_rd_stats_uv);
+ // Collect mode stats for multiwinner mode processing
+ const int txfm_search_done = 1;
+ if (intra_rd_stats.rdcost != INT64_MAX) {
+ store_winner_mode_stats(
+ &cpi->common, x, mbmi, &intra_rd_stats, &intra_rd_stats_y,
+ &intra_rd_stats_uv, mode_enum, NULL, bsize, intra_rd_stats.rdcost,
+ cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
+ }
if (intra_rd_stats.rdcost < search_state.best_rd) {
- const int txfm_search_done = 1;
update_search_state(&search_state, rd_cost, ctx, &intra_rd_stats,
&intra_rd_stats_y, &intra_rd_stats_uv, mode_enum, x,
txfm_search_done);
@@ -13639,11 +13770,13 @@
end_timing(cpi, handle_intra_mode_time);
#endif
+ int winner_mode_count =
+ cpi->sf.enable_multiwinner_mode_process ? x->winner_mode_count : 1;
// In effect only when fast tx search speed features are enabled.
- refine_winner_mode_tx(cpi, x, rd_cost, bsize, ctx,
- search_state.best_mode_index, &search_state.best_mbmode,
- yv12_mb, search_state.best_rate_y,
- search_state.best_rate_uv, &search_state.best_skip2);
+ refine_winner_mode_tx(
+ cpi, x, rd_cost, bsize, ctx, &search_state.best_mode_index,
+ &search_state.best_mbmode, yv12_mb, search_state.best_rate_y,
+ search_state.best_rate_uv, &search_state.best_skip2, winner_mode_count);
// Initialize default mode evaluation params
set_mode_eval_params(cpi, x, DEFAULT_EVAL);