JNT_COMP: remove COMPOUND_AVERAGE from compound type cdf
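Remove COMPOUND_AVERAGE from compound_type_cdf, since it is now signaled
together with compound_idx via comp_group_idx. The compound type CDF,
encodings and cost tables therefore have COMPOUND_TYPES - 1 entries and
are indexed by interinter_compound_type - 1. The COMPOUND_AVERAGE enum
value itself is still used elsewhere.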
Change-Id: Ie0d460aabf9252e80eb4130cfef9aaf0efc3969d
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 4f05d99..c669a4c 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -70,7 +70,11 @@
}
static struct av1_token interintra_mode_encodings[INTERINTRA_MODES];
+#if CONFIG_JNT_COMP
+static struct av1_token compound_type_encodings[COMPOUND_TYPES - 1];
+#else
static struct av1_token compound_type_encodings[COMPOUND_TYPES];
+#endif // CONFIG_JNT_COMP
#if CONFIG_LOOP_RESTORATION
static void loop_restoration_write_sb_coeffs(const AV1_COMMON *const cm,
MACROBLOCKD *xd,
@@ -1469,11 +1473,11 @@
if (cpi->common.reference_mode != SINGLE_REFERENCE &&
is_inter_compound_mode(mbmi->mode) &&
mbmi->motion_mode == SIMPLE_TRANSLATION &&
- is_any_masked_compound_used(bsize) && cm->allow_masked_compound &&
- mbmi->comp_group_idx) {
+ is_any_masked_compound_used(bsize) && cm->allow_masked_compound) {
if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
- aom_write_symbol(w, mbmi->interinter_compound_type,
- ec_ctx->compound_type_cdf[bsize], COMPOUND_TYPES);
+ aom_write_symbol(w, mbmi->interinter_compound_type - 1,
+ ec_ctx->compound_type_cdf[bsize],
+ COMPOUND_TYPES - 1);
if (is_interinter_compound_used(COMPOUND_WEDGE, bsize) &&
mbmi->interinter_compound_type == COMPOUND_WEDGE) {
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 8de0a78..ec99b38 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -271,7 +271,11 @@
int drl_mode_cost0[DRL_MODE_CONTEXTS][2];
int inter_compound_mode_cost[INTER_MODE_CONTEXTS][INTER_COMPOUND_MODES];
+#if CONFIG_JNT_COMP
+ int compound_type_cost[BLOCK_SIZES_ALL][COMPOUND_TYPES - 1];
+#else
int compound_type_cost[BLOCK_SIZES_ALL][COMPOUND_TYPES];
+#endif // CONFIG_JNT_COMP
int interintra_cost[BLOCK_SIZE_GROUPS][2];
int wedge_interintra_cost[BLOCK_SIZES_ALL][2];
int interintra_mode_cost[BLOCK_SIZE_GROUPS][INTERINTRA_MODES];
@@ -347,6 +351,7 @@
#endif // CONFIG_DIST_8X8
#if CONFIG_JNT_COMP
int comp_idx_cost[COMP_INDEX_CONTEXTS][2];
+ int comp_group_idx_cost[COMP_GROUP_IDX_CONTEXTS][2];
#endif // CONFIG_JNT_COMP
};
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index be77347..f3575b1 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1271,14 +1271,24 @@
mbmi->compound_idx, 2);
}
}
-#endif // CONFIG_JNT_COMP
if (cm->reference_mode != SINGLE_REFERENCE &&
- is_inter_compound_mode(mbmi->mode)
-#if CONFIG_JNT_COMP
- && mbmi->comp_group_idx
-#endif // CONFIG_JNT_COMP
- && mbmi->motion_mode == SIMPLE_TRANSLATION) {
+ is_inter_compound_mode(mbmi->mode) && mbmi->comp_group_idx &&
+ mbmi->motion_mode == SIMPLE_TRANSLATION) {
+ if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
+ counts->compound_interinter[bsize]
+ [mbmi->interinter_compound_type - 1]++;
+ if (allow_update_cdf)
+ update_cdf(fc->compound_type_cdf[bsize],
+ mbmi->interinter_compound_type - 1,
+ COMPOUND_TYPES - 1);
+ }
+ }
+#else // CONFIG_JNT_COMP
+
+ if (cm->reference_mode != SINGLE_REFERENCE &&
+ is_inter_compound_mode(mbmi->mode) &&
+ mbmi->motion_mode == SIMPLE_TRANSLATION) {
if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
counts
->compound_interinter[bsize][mbmi->interinter_compound_type]++;
@@ -1287,6 +1297,7 @@
mbmi->interinter_compound_type, COMPOUND_TYPES);
}
}
+#endif // CONFIG_JNT_COMP
}
}
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 8abee95..3bd3b42 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -284,6 +284,10 @@
av1_cost_tokens_from_cdf(x->comp_idx_cost[i], fc->compound_index_cdf[i],
NULL);
}
+ for (i = 0; i < COMP_GROUP_IDX_CONTEXTS; ++i) {
+ av1_cost_tokens_from_cdf(x->comp_group_idx_cost[i],
+ fc->comp_group_idx_cdf[i], NULL);
+ }
#endif // CONFIG_JNT_COMP
}
}
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index d7e722e..ca14064 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5661,6 +5661,7 @@
}
}
+#if !CONFIG_JNT_COMP
static int get_interinter_compound_type_bits(BLOCK_SIZE bsize,
COMPOUND_TYPE comp_type) {
(void)bsize;
@@ -5671,6 +5672,7 @@
default: assert(0); return 0;
}
}
+#endif
typedef struct {
int eobs;
@@ -8139,8 +8141,14 @@
#if CONFIG_JNT_COMP
if (is_comp_pred) {
- const int comp_index_ctx = get_comp_index_context(cm, xd);
- rd_stats->rate += x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
+ if (mbmi->compound_idx == 0) {
+ mbmi->comp_group_idx = 0;
+ const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+ rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ rd_stats->rate += x->comp_idx_cost[comp_index_ctx][0];
+ }
}
#endif // CONFIG_JNT_COMP
@@ -8307,6 +8315,28 @@
tmp_rate_mv = rate_mv;
best_rd_cur = INT64_MAX;
mbmi->interinter_compound_type = cur_type;
+#if CONFIG_JNT_COMP
+ const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
+ if (cur_type == COMPOUND_AVERAGE) {
+ mbmi->comp_group_idx = 0;
+ rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][0];
+
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ rs2 += x->comp_idx_cost[comp_index_ctx][1];
+ } else {
+ mbmi->comp_group_idx = 1;
+ rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][1];
+
+ int masked_type_cost = 0;
+ if (masked_compound_used) {
+ if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+ masked_type_cost +=
+ x->compound_type_cost[bsize]
+ [mbmi->interinter_compound_type - 1];
+ }
+ rs2 += masked_type_cost;
+ }
+#else
int masked_type_cost = 0;
if (masked_compound_used) {
if (!is_interinter_compound_used(COMPOUND_WEDGE, bsize))
@@ -8318,6 +8348,7 @@
rs2 = av1_cost_literal(get_interinter_compound_type_bits(
bsize, mbmi->interinter_compound_type)) +
masked_type_cost;
+#endif // CONFIG_JNT_COMP
switch (cur_type) {
case COMPOUND_AVERAGE: