Rework the txfm partition context to support cb4x4 mode
This commit reworks the transform block partition context update
to support cb4x4 mode in the recursive transform block partition.
It resolves the remaining enc/dec mismatch issue when both cb4x4
and var-tx are turned on.
Change-Id: I850d121204fe4c68e81488f1d2848c570d9d08b9
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index 7c9416d..f7659c4 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -793,8 +793,8 @@
static INLINE void txfm_partition_update(TXFM_CONTEXT *above_ctx,
TXFM_CONTEXT *left_ctx,
- TX_SIZE tx_size) {
- BLOCK_SIZE bsize = txsize_to_bsize[tx_size];
+ TX_SIZE tx_size, TX_SIZE txb_size) {
+ BLOCK_SIZE bsize = txsize_to_bsize[txb_size];
int bh = mi_size_high[bsize];
int bw = mi_size_wide[bsize];
uint8_t txw = tx_size_wide[tx_size];
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 59a646d..3a5f25d 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -332,7 +332,7 @@
mbmi->min_tx_size = AOMMIN(mbmi->min_tx_size, get_min_tx_size(tx_size));
if (counts) ++counts->txfm_partition[ctx][0];
txfm_partition_update(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, tx_size);
+ xd->left_txfm_context + tx_row, tx_size, tx_size);
return;
}
@@ -350,11 +350,11 @@
inter_tx_size[0][0] = TX_4X4;
for (idy = 0; idy < tx_size_high_unit[tx_size] / 2; ++idy)
for (idx = 0; idx < tx_size_wide_unit[tx_size] / 2; ++idx)
- inter_tx_size[idy][idx] = tx_size;
+ inter_tx_size[idy][idx] = inter_tx_size[0][0];
mbmi->tx_size = TX_4X4;
mbmi->min_tx_size = get_min_tx_size(mbmi->tx_size);
txfm_partition_update(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, TX_4X4);
+ xd->left_txfm_context + tx_row, TX_4X4, tx_size);
return;
}
@@ -375,7 +375,7 @@
mbmi->min_tx_size = AOMMIN(mbmi->min_tx_size, get_min_tx_size(tx_size));
if (counts) ++counts->txfm_partition[ctx][0];
txfm_partition_update(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, tx_size);
+ xd->left_txfm_context + tx_row, tx_size, tx_size);
}
}
#endif
@@ -1952,8 +1952,15 @@
xd->above_txfm_context = cm->above_txfm_context + mi_col;
xd->left_txfm_context =
xd->left_txfm_context_buffer + (mi_row & MAX_MIB_MASK);
- if (bsize >= BLOCK_8X8 && cm->tx_mode == TX_MODE_SELECT && !mbmi->skip &&
- inter_block) {
+
+ if (cm->tx_mode == TX_MODE_SELECT &&
+#if CONFIG_CB4X4
+ (bsize >= BLOCK_8X8 ||
+ (bsize >= BLOCK_4X4 && inter_block && !mbmi->skip)) &&
+#else
+ bsize >= BLOCK_8X8 &&
+#endif
+ !mbmi->skip && inter_block) {
const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
const int bh = tx_size_high_unit[max_tx_size];
const int bw = tx_size_wide_unit[max_tx_size];
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 27eb8df..094698b 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -383,14 +383,14 @@
if (depth == MAX_VARTX_DEPTH) {
txfm_partition_update(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, tx_size);
+ xd->left_txfm_context + tx_row, tx_size, tx_size);
return;
}
if (tx_size == mbmi->inter_tx_size[tx_row][tx_col]) {
aom_write(w, 0, cm->fc->txfm_partition_prob[ctx]);
txfm_partition_update(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, tx_size);
+ xd->left_txfm_context + tx_row, tx_size, tx_size);
} else {
const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
const int bsl = tx_size_wide_unit[sub_txs];
@@ -400,7 +400,7 @@
if (tx_size == TX_8X8) {
txfm_partition_update(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, TX_4X4);
+ xd->left_txfm_context + tx_row, TX_4X4, tx_size);
return;
}
@@ -1317,7 +1317,12 @@
if (!segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME))
aom_write(w, is_inter, av1_get_intra_inter_prob(cm, xd));
- if (bsize >= BLOCK_8X8 && cm->tx_mode == TX_MODE_SELECT &&
+ if (cm->tx_mode == TX_MODE_SELECT &&
+#if CONFIG_CB4X4 && CONFIG_VAR_TX
+ (bsize >= BLOCK_8X8 || (bsize >= BLOCK_4X4 && is_inter && !skip)) &&
+#else
+ bsize >= BLOCK_8X8 &&
+#endif
#if CONFIG_SUPERTX
!supertx_enabled &&
#endif // CONFIG_SUPERTX
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index a32d002..1933345 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -5231,7 +5231,7 @@
++counts->txfm_partition[ctx][0];
mbmi->tx_size = tx_size;
txfm_partition_update(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, tx_size);
+ xd->left_txfm_context + tx_row, tx_size, tx_size);
} else {
const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
const int bs = tx_size_wide_unit[sub_txs];
@@ -5244,7 +5244,7 @@
mbmi->inter_tx_size[tx_row][tx_col] = TX_4X4;
mbmi->tx_size = TX_4X4;
txfm_partition_update(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, TX_4X4);
+ xd->left_txfm_context + tx_row, TX_4X4, tx_size);
return;
}
@@ -5292,7 +5292,7 @@
if (tx_size == plane_tx_size) {
mbmi->tx_size = tx_size;
txfm_partition_update(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, tx_size);
+ xd->left_txfm_context + tx_row, tx_size, tx_size);
} else {
const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
@@ -5303,7 +5303,7 @@
mbmi->inter_tx_size[tx_row][tx_col] = TX_4X4;
mbmi->tx_size = TX_4X4;
txfm_partition_update(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, TX_4X4);
+ xd->left_txfm_context + tx_row, TX_4X4, tx_size);
return;
}
@@ -5496,26 +5496,39 @@
#else
TX_SIZE tx_size = mbmi->tx_size;
#endif
- if (cm->tx_mode == TX_MODE_SELECT && mbmi->sb_type >= BLOCK_8X8 &&
+ if (cm->tx_mode == TX_MODE_SELECT &&
+#if CONFIG_CB4X4 && CONFIG_VAR_TX
+ (mbmi->sb_type >= BLOCK_8X8 ||
+ (mbmi->sb_type >= BLOCK_4X4 && is_inter &&
+ !(mbmi->skip || seg_skip))) &&
+#else
+ mbmi->sb_type >= BLOCK_8X8 &&
+#endif
!(is_inter && (mbmi->skip || seg_skip))) {
+#if CONFIG_VAR_TX
+ if (is_inter) {
+ tx_partition_count_update(cm, x, bsize, mi_row, mi_col, td->counts);
+ } else {
+ const int tx_size_ctx = get_tx_size_context(xd);
+ const int tx_size_cat = is_inter ? inter_tx_size_cat_lookup[bsize]
+ : intra_tx_size_cat_lookup[bsize];
+ const TX_SIZE coded_tx_size = txsize_sqr_up_map[tx_size];
+ const int depth = tx_size_to_depth(coded_tx_size);
+ ++td->counts->tx_size[tx_size_cat][tx_size_ctx][depth];
+ if (tx_size != max_txsize_lookup[bsize]) ++x->txb_split_count;
+ }
+#else
const int tx_size_ctx = get_tx_size_context(xd);
const int tx_size_cat = is_inter ? inter_tx_size_cat_lookup[bsize]
: intra_tx_size_cat_lookup[bsize];
const TX_SIZE coded_tx_size = txsize_sqr_up_map[tx_size];
const int depth = tx_size_to_depth(coded_tx_size);
+
+ ++td->counts->tx_size[tx_size_cat][tx_size_ctx][depth];
+#endif
#if CONFIG_EXT_TX && CONFIG_RECT_TX
assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed(xd, mbmi)));
#endif // CONFIG_EXT_TX && CONFIG_RECT_TX
-#if CONFIG_VAR_TX
- if (is_inter) {
- tx_partition_count_update(cm, x, bsize, mi_row, mi_col, td->counts);
- } else {
- ++td->counts->tx_size[tx_size_cat][tx_size_ctx][depth];
- if (tx_size != max_txsize_lookup[bsize]) ++x->txb_split_count;
- }
-#else
- ++td->counts->tx_size[tx_size_cat][tx_size_ctx][depth];
-#endif
} else {
int i, j;
TX_SIZE intra_tx_size;
@@ -5577,8 +5590,13 @@
}
#if CONFIG_VAR_TX
- if (cm->tx_mode == TX_MODE_SELECT && mbmi->sb_type >= BLOCK_8X8 && is_inter &&
- !(mbmi->skip || seg_skip)) {
+ if (cm->tx_mode == TX_MODE_SELECT &&
+#if CONFIG_CB4X4
+ mbmi->sb_type >= BLOCK_4X4 &&
+#else
+ mbmi->sb_type >= BLOCK_8X8 &&
+#endif
+ is_inter && !(mbmi->skip || seg_skip)) {
if (dry_run) tx_partition_set_contexts(cm, xd, bsize, mi_row, mi_col);
} else {
TX_SIZE tx_size = mbmi->tx_size;
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 25c2de8..a869f82 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -860,7 +860,7 @@
int idx, idy;
int block = 0;
int step = tx_size_wide_unit[max_tx_size] * tx_size_high_unit[max_tx_size];
- av1_get_entropy_contexts(bsize, TX_4X4, pd, ctx.ta[plane], ctx.tl[plane]);
+ av1_get_entropy_contexts(bsize, 0, pd, ctx.ta[plane], ctx.tl[plane]);
#else
const struct macroblockd_plane *const pd = &xd->plane[plane];
const TX_SIZE tx_size = plane ? get_uv_tx_size(mbmi, pd) : mbmi->tx_size;
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index eb9a951..f9ab467 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3407,7 +3407,7 @@
for (i = 0; i < tx_size_wide_unit[tx_size]; ++i) pta[i] = !(tmp_eob == 0);
for (i = 0; i < tx_size_high_unit[tx_size]; ++i) ptl[i] = !(tmp_eob == 0);
txfm_partition_update(tx_above + (blk_col >> 1), tx_left + (blk_row >> 1),
- tx_size);
+ tx_size, tx_size);
inter_tx_size[0][0] = tx_size;
for (idy = 0; idy < tx_size_high_unit[tx_size] / 2; ++idy)
for (idx = 0; idx < tx_size_wide_unit[tx_size] / 2; ++idx)