Rework transform block partition context model
This commit allows the partition context model to account for the
maximum transform block size of the coding block.
Change-Id: I22b91e85fff70faa974afd362ce327d3f2eda81d
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 201bb16..cea5769 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -984,7 +984,7 @@
#if CONFIG_VAR_TX
static const aom_prob default_txfm_partition_probs[TXFM_PARTITION_CONTEXTS] = {
- 192, 128, 64, 192, 128, 64, 192, 128, 64,
+ 250, 231, 212, 241, 166, 66, 241, 230, 135, 243, 154, 64, 248, 161, 63, 128,
};
#endif
diff --git a/av1/common/enums.h b/av1/common/enums.h
index a684eed..e274d4b 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -408,7 +408,7 @@
#define REF_CONTEXTS 5
#if CONFIG_VAR_TX
-#define TXFM_PARTITION_CONTEXTS 9
+#define TXFM_PARTITION_CONTEXTS 16
typedef TX_SIZE TXFM_CONTEXT;
#endif
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index 7c12bd3..c0ed538 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -756,11 +756,30 @@
static INLINE int txfm_partition_context(TXFM_CONTEXT *above_ctx,
TXFM_CONTEXT *left_ctx,
- TX_SIZE tx_size) {
- int above = *above_ctx < tx_size;
- int left = *left_ctx < tx_size;
+ const BLOCK_SIZE bsize,
+ const TX_SIZE tx_size) {
+ const int above = *above_ctx < tx_size;
+ const int left = *left_ctx < tx_size;
+ TX_SIZE max_tx_size = max_txsize_lookup[bsize];
+ int category = 15;
- return (tx_size - TX_8X8) * 3 + above + left;
+ if (max_tx_size == TX_32X32) {
+ if (tx_size == TX_32X32)
+ category = 0;
+ else
+ category = 1;
+ } else if (max_tx_size == TX_16X16) {
+ if (tx_size == TX_16X16)
+ category = 2;
+ else
+ category = 3;
+ } else if (max_tx_size == TX_8X8) {
+ category = 4;
+ }
+
+ if (category == 15) return category;
+
+ return category * 3 + above + left;
}
#endif
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 85f8111..77cea8a 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -301,7 +301,8 @@
int max_blocks_high = block_size_high[mbmi->sb_type];
int max_blocks_wide = block_size_wide[mbmi->sb_type];
int ctx = txfm_partition_context(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, tx_size);
+ xd->left_txfm_context + tx_row,
+ mbmi->sb_type, tx_size);
TX_SIZE(*const inter_tx_size)
[MAX_MIB_SIZE] =
(TX_SIZE(*)[MAX_MIB_SIZE]) & mbmi->inter_tx_size[tx_row][tx_col];
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index e232a87..4275098 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -360,7 +360,8 @@
const int max_blocks_wide = max_block_wide(xd, mbmi->sb_type, 0);
int ctx = txfm_partition_context(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, tx_size);
+ xd->left_txfm_context + tx_row,
+ mbmi->sb_type, tx_size);
if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 5ad334d..d1d6ecc 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4987,7 +4987,7 @@
#if CONFIG_VAR_TX
static void update_txfm_count(MACROBLOCK *x, MACROBLOCKD *xd,
- FRAME_COUNTS *counts, TX_SIZE tx_size,
+ FRAME_COUNTS *counts, TX_SIZE tx_size, int depth,
int blk_row, int blk_col) {
MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
const int tx_row = blk_row >> 1;
@@ -4995,7 +4995,8 @@
const int max_blocks_high = max_block_high(xd, mbmi->sb_type, 0);
const int max_blocks_wide = max_block_wide(xd, mbmi->sb_type, 0);
int ctx = txfm_partition_context(xd->above_txfm_context + tx_col,
- xd->left_txfm_context + tx_row, tx_size);
+ xd->left_txfm_context + tx_row,
+ mbmi->sb_type, tx_size);
const TX_SIZE plane_tx_size = mbmi->inter_tx_size[tx_row][tx_col];
if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
@@ -5023,8 +5024,8 @@
for (i = 0; i < 4; ++i) {
int offsetr = (i >> 1) * bh / 2;
int offsetc = (i & 0x01) * bh / 2;
- update_txfm_count(x, xd, counts, tx_size - 1, blk_row + offsetr,
- blk_col + offsetc);
+ update_txfm_count(x, xd, counts, tx_size - 1, depth + 1,
+ blk_row + offsetr, blk_col + offsetc);
}
}
}
@@ -5046,7 +5047,8 @@
for (idy = 0; idy < mi_height; idy += bh)
for (idx = 0; idx < mi_width; idx += bh)
- update_txfm_count(x, xd, td_counts, max_tx_size, idy, idx);
+ update_txfm_count(x, xd, td_counts, max_tx_size, mi_width != mi_height,
+ idy, idx);
}
static void set_txfm_context(MACROBLOCKD *xd, TX_SIZE tx_size, int blk_row,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 46dbae4..b37714c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3064,8 +3064,9 @@
ENTROPY_CONTEXT *pta = ta + blk_col;
ENTROPY_CONTEXT *ptl = tl + blk_row;
int coeff_ctx, i;
- int ctx = txfm_partition_context(tx_above + (blk_col >> 1),
- tx_left + (blk_row >> 1), tx_size);
+ int ctx =
+ txfm_partition_context(tx_above + (blk_col >> 1),
+ tx_left + (blk_row >> 1), mbmi->sb_type, tx_size);
int64_t sum_dist = 0, sum_bsse = 0;
int64_t sum_rd = INT64_MAX;