Rework transform block partition context model

This commit allows the partition context model to account for the
maximum transform block size of the coding block.

Change-Id: I22b91e85fff70faa974afd362ce327d3f2eda81d
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 201bb16..cea5769 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -984,7 +984,7 @@
 
 #if CONFIG_VAR_TX
 static const aom_prob default_txfm_partition_probs[TXFM_PARTITION_CONTEXTS] = {
-  192, 128, 64, 192, 128, 64, 192, 128, 64,
+  250, 231, 212, 241, 166, 66, 241, 230, 135, 243, 154, 64, 248, 161, 63, 128,
 };
 #endif
 
diff --git a/av1/common/enums.h b/av1/common/enums.h
index a684eed..e274d4b 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -408,7 +408,7 @@
 #define REF_CONTEXTS 5
 
 #if CONFIG_VAR_TX
-#define TXFM_PARTITION_CONTEXTS 9
+#define TXFM_PARTITION_CONTEXTS 16
 typedef TX_SIZE TXFM_CONTEXT;
 #endif
 
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index 7c12bd3..c0ed538 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -756,11 +756,30 @@
 
 static INLINE int txfm_partition_context(TXFM_CONTEXT *above_ctx,
                                          TXFM_CONTEXT *left_ctx,
-                                         TX_SIZE tx_size) {
-  int above = *above_ctx < tx_size;
-  int left = *left_ctx < tx_size;
+                                         const BLOCK_SIZE bsize,
+                                         const TX_SIZE tx_size) {
+  const int above = *above_ctx < tx_size;
+  const int left = *left_ctx < tx_size;
+  TX_SIZE max_tx_size = max_txsize_lookup[bsize];
+  int category = 15;
 
-  return (tx_size - TX_8X8) * 3 + above + left;
+  if (max_tx_size == TX_32X32) {
+    if (tx_size == TX_32X32)
+      category = 0;
+    else
+      category = 1;
+  } else if (max_tx_size == TX_16X16) {
+    if (tx_size == TX_16X16)
+      category = 2;
+    else
+      category = 3;
+  } else if (max_tx_size == TX_8X8) {
+    category = 4;
+  }
+
+  if (category == 15) return category;
+
+  return category * 3 + above + left;
 }
 #endif
 
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 85f8111..77cea8a 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -301,7 +301,8 @@
   int max_blocks_high = block_size_high[mbmi->sb_type];
   int max_blocks_wide = block_size_wide[mbmi->sb_type];
   int ctx = txfm_partition_context(xd->above_txfm_context + tx_col,
-                                   xd->left_txfm_context + tx_row, tx_size);
+                                   xd->left_txfm_context + tx_row,
+                                   mbmi->sb_type, tx_size);
   TX_SIZE(*const inter_tx_size)
   [MAX_MIB_SIZE] =
       (TX_SIZE(*)[MAX_MIB_SIZE]) & mbmi->inter_tx_size[tx_row][tx_col];
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index e232a87..4275098 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -360,7 +360,8 @@
   const int max_blocks_wide = max_block_wide(xd, mbmi->sb_type, 0);
 
   int ctx = txfm_partition_context(xd->above_txfm_context + tx_col,
-                                   xd->left_txfm_context + tx_row, tx_size);
+                                   xd->left_txfm_context + tx_row,
+                                   mbmi->sb_type, tx_size);
 
   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
 
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 5ad334d..d1d6ecc 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4987,7 +4987,7 @@
 
 #if CONFIG_VAR_TX
 static void update_txfm_count(MACROBLOCK *x, MACROBLOCKD *xd,
-                              FRAME_COUNTS *counts, TX_SIZE tx_size,
+                              FRAME_COUNTS *counts, TX_SIZE tx_size, int depth,
                               int blk_row, int blk_col) {
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const int tx_row = blk_row >> 1;
@@ -4995,7 +4995,8 @@
   const int max_blocks_high = max_block_high(xd, mbmi->sb_type, 0);
   const int max_blocks_wide = max_block_wide(xd, mbmi->sb_type, 0);
   int ctx = txfm_partition_context(xd->above_txfm_context + tx_col,
-                                   xd->left_txfm_context + tx_row, tx_size);
+                                   xd->left_txfm_context + tx_row,
+                                   mbmi->sb_type, tx_size);
   const TX_SIZE plane_tx_size = mbmi->inter_tx_size[tx_row][tx_col];
 
   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
@@ -5023,8 +5024,8 @@
     for (i = 0; i < 4; ++i) {
       int offsetr = (i >> 1) * bh / 2;
       int offsetc = (i & 0x01) * bh / 2;
-      update_txfm_count(x, xd, counts, tx_size - 1, blk_row + offsetr,
-                        blk_col + offsetc);
+      update_txfm_count(x, xd, counts, tx_size - 1, depth + 1,
+                        blk_row + offsetr, blk_col + offsetc);
     }
   }
 }
@@ -5046,7 +5047,8 @@
 
   for (idy = 0; idy < mi_height; idy += bh)
     for (idx = 0; idx < mi_width; idx += bh)
-      update_txfm_count(x, xd, td_counts, max_tx_size, idy, idx);
+      update_txfm_count(x, xd, td_counts, max_tx_size, mi_width != mi_height,
+                        idy, idx);
 }
 
 static void set_txfm_context(MACROBLOCKD *xd, TX_SIZE tx_size, int blk_row,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 46dbae4..b37714c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3064,8 +3064,9 @@
   ENTROPY_CONTEXT *pta = ta + blk_col;
   ENTROPY_CONTEXT *ptl = tl + blk_row;
   int coeff_ctx, i;
-  int ctx = txfm_partition_context(tx_above + (blk_col >> 1),
-                                   tx_left + (blk_row >> 1), tx_size);
+  int ctx =
+      txfm_partition_context(tx_above + (blk_col >> 1),
+                             tx_left + (blk_row >> 1), mbmi->sb_type, tx_size);
 
   int64_t sum_dist = 0, sum_bsse = 0;
   int64_t sum_rd = INT64_MAX;