Change comp_group_idx context and skip sending comp_group_idx
Extend context model for comp_group_idx.
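Skip sending comp_group_idx when masked compound prediction is not allowed, i.e. when the block size supports no masked compound type or allow_masked_compound is off.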
Change-Id: Ia7ae53958c9e1c8fe07be4b14a425d9b8648082d
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index c669a4c..e4f1fec 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1454,10 +1454,17 @@
// First write idx to indicate current compound inter prediction mode group
// Group A (0): jnt_comp, compound_average
// Group B (1): interintra, compound_segment, wedge
+ int masked_compound_used = is_any_masked_compound_used(bsize);
+ masked_compound_used = masked_compound_used && cm->allow_masked_compound;
+
if (has_second_ref(mbmi)) {
- const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
- aom_write_symbol(w, mbmi->comp_group_idx,
- ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
+ if (masked_compound_used) {
+ assert(mbmi->comp_group_idx == 0);
+
+ const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
+ aom_write_symbol(w, mbmi->comp_group_idx,
+ ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
+ }
if (mbmi->comp_group_idx == 0) {
if (mbmi->compound_idx)
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index ca14064..1ec0cca 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5661,7 +5661,6 @@
}
}
-#if !CONFIG_JNT_COMP
static int get_interinter_compound_type_bits(BLOCK_SIZE bsize,
COMPOUND_TYPE comp_type) {
(void)bsize;
@@ -5672,7 +5671,6 @@
default: assert(0); return 0;
}
}
-#endif
typedef struct {
int eobs;
@@ -8142,12 +8140,16 @@
#if CONFIG_JNT_COMP
if (is_comp_pred) {
if (mbmi->compound_idx == 0) {
- mbmi->comp_group_idx = 0;
- const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
- rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+ int masked_compound_used = is_any_masked_compound_used(bsize);
+ masked_compound_used = masked_compound_used && cm->allow_masked_compound;
+
+ if (masked_compound_used) {
+ const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+ rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+ }
const int comp_index_ctx = get_comp_index_context(cm, xd);
- rd_stats->rate += x->comp_idx_cost[comp_index_ctx][0];
+ rd_stats->rate += x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
}
}
#endif // CONFIG_JNT_COMP
@@ -8316,26 +8318,27 @@
best_rd_cur = INT64_MAX;
mbmi->interinter_compound_type = cur_type;
#if CONFIG_JNT_COMP
- const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
- if (cur_type == COMPOUND_AVERAGE) {
- mbmi->comp_group_idx = 0;
- rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][0];
+ const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+ int masked_type_cost = 0;
+ if (masked_compound_used) {
+ if (cur_type == COMPOUND_AVERAGE) {
+ masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0];
- const int comp_index_ctx = get_comp_index_context(cm, xd);
- rs2 += x->comp_idx_cost[comp_index_ctx][1];
- } else {
- mbmi->comp_group_idx = 1;
- rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][1];
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
+ } else {
+ masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1];
- int masked_type_cost = 0;
- if (masked_compound_used) {
- if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
- masked_type_cost +=
- x->compound_type_cost[bsize]
- [mbmi->interinter_compound_type - 1];
+ masked_type_cost +=
+ x->compound_type_cost[bsize][mbmi->interinter_compound_type - 1];
}
- rs2 += masked_type_cost;
+ } else {
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
}
+ rs2 = av1_cost_literal(get_interinter_compound_type_bits(
+ bsize, mbmi->interinter_compound_type)) +
+ masked_type_cost;
#else
int masked_type_cost = 0;
if (masked_compound_used) {