Enable entropy coding of recursive transform block partition

This commit enables the entropy coding of the recursive transform
block partition syntax.

Change-Id: I0c2509fb7b9822d12a721f9ebf9327fac83c777e
diff --git a/vp10/common/alloccommon.c b/vp10/common/alloccommon.c
index 9ca86e5..364afde 100644
--- a/vp10/common/alloccommon.c
+++ b/vp10/common/alloccommon.c
@@ -97,6 +97,10 @@
   cm->above_context = NULL;
   vpx_free(cm->above_seg_context);
   cm->above_seg_context = NULL;
+#if CONFIG_VAR_TX
+  vpx_free(cm->above_txfm_context);
+  cm->above_txfm_context = NULL;
+#endif
 }
 
 int vp10_alloc_context_buffers(VP10_COMMON *cm, int width, int height) {
@@ -128,6 +132,14 @@
     cm->above_seg_context = (PARTITION_CONTEXT *)vpx_calloc(
         mi_cols_aligned_to_sb(cm->mi_cols), sizeof(*cm->above_seg_context));
     if (!cm->above_seg_context) goto fail;
+
+#if CONFIG_VAR_TX
+    vpx_free(cm->above_txfm_context);
+    cm->above_txfm_context = (TXFM_CONTEXT *)vpx_calloc(
+        mi_cols_aligned_to_sb(cm->mi_cols), sizeof(*cm->above_txfm_context));
+    if (!cm->above_txfm_context) goto fail;
+#endif
+
     cm->above_context_alloc_cols = cm->mi_cols;
   }
 
diff --git a/vp10/common/blockd.h b/vp10/common/blockd.h
index 4bf06d1..6c01818 100644
--- a/vp10/common/blockd.h
+++ b/vp10/common/blockd.h
@@ -216,6 +216,12 @@
   PARTITION_CONTEXT *above_seg_context;
   PARTITION_CONTEXT left_seg_context[8];
 
+#if CONFIG_VAR_TX
+  TXFM_CONTEXT *above_txfm_context;
+  TXFM_CONTEXT *left_txfm_context;
+  TXFM_CONTEXT left_txfm_context_buffer[8];
+#endif
+
 #if CONFIG_VP9_HIGHBITDEPTH
   /* Bit depth: 8, 10, 12 */
   int bd;
diff --git a/vp10/common/entropymode.c b/vp10/common/entropymode.c
index 59e6df8..af15b2d 100644
--- a/vp10/common/entropymode.c
+++ b/vp10/common/entropymode.c
@@ -742,6 +742,12 @@
   ct_8x8p[0][1] = tx_count_8x8p[TX_8X8];
 }
 
+#if CONFIG_VAR_TX
+static const vpx_prob default_txfm_partition_probs[TXFM_PARTITION_CONTEXTS] = {
+    192, 128, 64, 192, 128, 64, 192, 128, 64,
+};
+#endif
+
 static const vpx_prob default_skip_probs[SKIP_CONTEXTS] = {
   192, 128, 64
 };
@@ -959,6 +965,9 @@
   vp10_copy(fc->comp_ref_prob, default_comp_ref_p);
   vp10_copy(fc->single_ref_prob, default_single_ref_p);
   fc->tx_probs = default_tx_probs;
+#if CONFIG_VAR_TX
+  vp10_copy(fc->txfm_partition_prob, default_txfm_partition_probs);
+#endif
   vp10_copy(fc->skip_probs, default_skip_probs);
   vp10_copy(fc->inter_mode_probs, default_inter_mode_probs);
 #if CONFIG_EXT_TX
@@ -1054,6 +1063,14 @@
     }
   }
 
+#if CONFIG_VAR_TX
+  if (cm->tx_mode == TX_MODE_SELECT)
+    for (i = 0; i < TXFM_PARTITION_CONTEXTS; ++i)
+      fc->txfm_partition_prob[i] =
+          mode_mv_merge_probs(pre_fc->txfm_partition_prob[i],
+                              counts->txfm_partition[i]);
+#endif
+
   for (i = 0; i < SKIP_CONTEXTS; ++i)
     fc->skip_probs[i] = mode_mv_merge_probs(
         pre_fc->skip_probs[i], counts->skip[i]);
diff --git a/vp10/common/entropymode.h b/vp10/common/entropymode.h
index b53b4e1..cb7807d 100644
--- a/vp10/common/entropymode.h
+++ b/vp10/common/entropymode.h
@@ -67,6 +67,9 @@
   vpx_prob single_ref_prob[REF_CONTEXTS][2];
   vpx_prob comp_ref_prob[REF_CONTEXTS];
   struct tx_probs tx_probs;
+#if CONFIG_VAR_TX
+  vpx_prob txfm_partition_prob[TXFM_PARTITION_CONTEXTS];
+#endif
   vpx_prob skip_probs[SKIP_CONTEXTS];
   nmv_context nmvc;
 #if CONFIG_EXT_TX
@@ -96,6 +99,9 @@
   unsigned int single_ref[REF_CONTEXTS][2][2];
   unsigned int comp_ref[REF_CONTEXTS][2];
   struct tx_counts tx;
+#if CONFIG_VAR_TX
+  unsigned int txfm_partition[TXFM_PARTITION_CONTEXTS][2];
+#endif
   unsigned int skip[SKIP_CONTEXTS][2];
   nmv_context_counts mv;
 #if CONFIG_EXT_TX
diff --git a/vp10/common/enums.h b/vp10/common/enums.h
index 33c5eaf..53db356 100644
--- a/vp10/common/enums.h
+++ b/vp10/common/enums.h
@@ -187,6 +187,11 @@
 #define COMP_INTER_CONTEXTS 5
 #define REF_CONTEXTS 5
 
+#if CONFIG_VAR_TX
+#define TXFM_PARTITION_CONTEXTS 9
+typedef TX_SIZE TXFM_CONTEXT;
+#endif
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/vp10/common/onyxc_int.h b/vp10/common/onyxc_int.h
index 6814133..4178fe7 100644
--- a/vp10/common/onyxc_int.h
+++ b/vp10/common/onyxc_int.h
@@ -301,6 +301,9 @@
 
   PARTITION_CONTEXT *above_seg_context;
   ENTROPY_CONTEXT *above_context;
+#if CONFIG_VAR_TX
+  TXFM_CONTEXT *above_txfm_context;
+#endif
   int above_context_alloc_cols;
 
   // scratch memory for intraonly/keyframe forward updates from default tables
@@ -397,6 +400,9 @@
   }
 
   xd->above_seg_context = cm->above_seg_context;
+#if CONFIG_VAR_TX
+  xd->above_txfm_context = cm->above_txfm_context;
+#endif
   xd->mi_stride = cm->mi_stride;
   xd->error_info = &cm->error;
 }
@@ -489,6 +495,28 @@
   return (left * 2 + above) + bsl * PARTITION_PLOFFSET;
 }
 
+#if CONFIG_VAR_TX
+static INLINE void txfm_partition_update(TXFM_CONTEXT *above_ctx,
+                                         TXFM_CONTEXT *left_ctx,
+                                         TX_SIZE tx_size) {
+  BLOCK_SIZE bsize = txsize_to_bsize[tx_size];
+  int bs = num_8x8_blocks_high_lookup[bsize];
+  int i;
+  for (i = 0; i < bs; ++i) {
+    above_ctx[i] = tx_size;
+    left_ctx[i] = tx_size;
+  }
+}
+
+static INLINE int txfm_partition_context(TXFM_CONTEXT *above_ctx,
+                                         TXFM_CONTEXT *left_ctx,
+                                         TX_SIZE tx_size) {
+  int above = *above_ctx < tx_size;
+  int left = *left_ctx < tx_size;
+  return (tx_size - 1) * 3 + above + left;
+}
+#endif
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/vp10/common/thread_common.c b/vp10/common/thread_common.c
index 7508195..6395e96 100644
--- a/vp10/common/thread_common.c
+++ b/vp10/common/thread_common.c
@@ -401,6 +401,12 @@
   for (i = 0; i < TX_SIZES; i++)
     cm->counts.tx.tx_totals[i] += counts->tx.tx_totals[i];
 
+#if CONFIG_VAR_TX
+  for (i = 0; i < TXFM_PARTITION_CONTEXTS; ++i)
+    for (j = 0; j < 2; ++j)
+      cm->counts.txfm_partition[i][j] += counts->txfm_partition[i][j];
+#endif
+
   for (i = 0; i < SKIP_CONTEXTS; i++)
     for (j = 0; j < 2; j++)
       cm->counts.skip[i][j] += counts->skip[i][j];
diff --git a/vp10/decoder/decodeframe.c b/vp10/decoder/decodeframe.c
index eaa31f1..f61ac2a 100644
--- a/vp10/decoder/decodeframe.c
+++ b/vp10/decoder/decodeframe.c
@@ -1619,6 +1619,11 @@
   memset(cm->above_seg_context, 0,
          sizeof(*cm->above_seg_context) * aligned_cols);
 
+#if CONFIG_VAR_TX
+  memset(cm->above_txfm_context, 0,
+         sizeof(*cm->above_txfm_context) * aligned_cols);
+#endif
+
   get_tile_buffers(pbi, data, data_end, tile_cols, tile_rows, tile_buffers);
 
   if (pbi->tile_data == NULL ||
@@ -1665,6 +1670,9 @@
         vp10_tile_set_col(&tile, tile_data->cm, col);
         vp10_zero(tile_data->xd.left_context);
         vp10_zero(tile_data->xd.left_seg_context);
+#if CONFIG_VAR_TX
+        vp10_zero(tile_data->xd.left_txfm_context_buffer);
+#endif
         for (mi_col = tile.mi_col_start; mi_col < tile.mi_col_end;
              mi_col += MI_BLOCK_SIZE) {
           decode_partition(pbi, &tile_data->xd, mi_row,
@@ -1738,6 +1746,9 @@
        mi_row += MI_BLOCK_SIZE) {
     vp10_zero(tile_data->xd.left_context);
     vp10_zero(tile_data->xd.left_seg_context);
+#if CONFIG_VAR_TX
+    vp10_zero(tile_data->xd.left_txfm_context_buffer);
+#endif
     for (mi_col = tile->mi_col_start; mi_col < tile->mi_col_end;
          mi_col += MI_BLOCK_SIZE) {
       decode_partition(tile_data->pbi, &tile_data->xd,
@@ -1815,7 +1826,10 @@
          sizeof(*cm->above_context) * MAX_MB_PLANE * 2 * aligned_mi_cols);
   memset(cm->above_seg_context, 0,
          sizeof(*cm->above_seg_context) * aligned_mi_cols);
-
+#if CONFIG_VAR_TX
+  memset(cm->above_txfm_context, 0,
+         sizeof(*cm->above_txfm_context) * aligned_mi_cols);
+#endif
   // Load tile data into tile_buffers
   get_tile_buffers(pbi, data, data_end, tile_cols, tile_rows, tile_buffers);
 
@@ -2270,6 +2284,11 @@
     read_tx_mode_probs(&fc->tx_probs, &r);
   read_coef_probs(fc, cm->tx_mode, &r);
 
+#if CONFIG_VAR_TX
+  for (k = 0; k < TXFM_PARTITION_CONTEXTS; ++k)
+    vp10_diff_update_prob(&r, &fc->txfm_partition_prob[k]);
+#endif
+
   for (k = 0; k < SKIP_CONTEXTS; ++k)
     vp10_diff_update_prob(&r, &fc->skip_probs[k]);
 
diff --git a/vp10/decoder/decodemv.c b/vp10/decoder/decodemv.c
index 803d0df..a4fb8de 100644
--- a/vp10/decoder/decodemv.c
+++ b/vp10/decoder/decodemv.c
@@ -80,13 +80,17 @@
 
 #if CONFIG_VAR_TX
 static void read_tx_size_inter(VP10_COMMON *cm, MACROBLOCKD *xd,
-                               MB_MODE_INFO *mbmi,
+                               MB_MODE_INFO *mbmi, FRAME_COUNTS *counts,
                                TX_SIZE tx_size, int blk_row, int blk_col,
                                vpx_reader *r) {
   int is_split = 0;
   const int tx_idx = (blk_row >> 1) * 8 + (blk_col >> 1);
   int max_blocks_high = num_4x4_blocks_high_lookup[mbmi->sb_type];
   int max_blocks_wide = num_4x4_blocks_wide_lookup[mbmi->sb_type];
+  int ctx = txfm_partition_context(xd->above_txfm_context + (blk_col >> 1),
+                                   xd->left_txfm_context + (blk_row >> 1),
+                                   tx_size);
+
   if (xd->mb_to_bottom_edge < 0)
     max_blocks_high += xd->mb_to_bottom_edge >> 5;
   if (xd->mb_to_right_edge < 0)
@@ -95,15 +99,21 @@
   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide)
      return;
 
-  is_split = vpx_read_bit(r);
+  is_split = vpx_read(r, cm->fc->txfm_partition_prob[ctx]);
 
   if (is_split) {
     BLOCK_SIZE bsize = txsize_to_bsize[tx_size];
     int bsl = b_width_log2_lookup[bsize];
     int i;
+
+    if (counts)
+      ++counts->txfm_partition[ctx][1];
+
     if (tx_size == TX_8X8) {
       mbmi->inter_tx_size[tx_idx] = TX_4X4;
       mbmi->tx_size = mbmi->inter_tx_size[tx_idx];
+      txfm_partition_update(xd->above_txfm_context + (blk_col >> 1),
+                            xd->left_txfm_context + (blk_row >> 1), TX_4X4);
       return;
     }
 
@@ -112,7 +122,8 @@
     for (i = 0; i < 4; ++i) {
       int offsetr = blk_row + ((i >> 1) << bsl);
       int offsetc = blk_col + ((i & 0x01) << bsl);
-      read_tx_size_inter(cm, xd, mbmi, tx_size - 1, offsetr, offsetc, r);
+      read_tx_size_inter(cm, xd, mbmi, counts,
+                         tx_size - 1, offsetr, offsetc, r);
     }
   } else {
     int idx, idy;
@@ -121,6 +132,10 @@
       for (idx = 0; idx < (1 << tx_size) / 2; ++idx)
         mbmi->inter_tx_size[tx_idx + (idy << 3) + idx] = tx_size;
     mbmi->tx_size = mbmi->inter_tx_size[tx_idx];
+    if (counts)
+      ++counts->txfm_partition[ctx][0];
+    txfm_partition_update(xd->above_txfm_context + (blk_col >> 1),
+                          xd->left_txfm_context + (blk_row >> 1), tx_size);
   }
 }
 #endif
@@ -764,9 +779,11 @@
     const int width  = num_4x4_blocks_wide_lookup[bsize];
     const int height = num_4x4_blocks_high_lookup[bsize];
     int idx, idy;
+    xd->above_txfm_context = cm->above_txfm_context + mi_col;
+    xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
     for (idy = 0; idy < height; idy += bs)
       for (idx = 0; idx < width; idx += bs)
-        read_tx_size_inter(cm, xd, mbmi, max_tx_size,
+        read_tx_size_inter(cm, xd, mbmi, xd->counts, max_tx_size,
                            idy, idx, r);
     if (xd->counts) {
       const int ctx = get_tx_size_context(xd);
diff --git a/vp10/encoder/bitstream.c b/vp10/encoder/bitstream.c
index e4dcc00..1bc3d58 100644
--- a/vp10/encoder/bitstream.c
+++ b/vp10/encoder/bitstream.c
@@ -160,6 +160,10 @@
   const int tx_idx = (blk_row >> 1) * 8 + (blk_col >> 1);
   int max_blocks_high = num_4x4_blocks_high_lookup[mbmi->sb_type];
   int max_blocks_wide = num_4x4_blocks_wide_lookup[mbmi->sb_type];
+  int ctx = txfm_partition_context(xd->above_txfm_context + (blk_col >> 1),
+                                   xd->left_txfm_context + (blk_row >> 1),
+                                   tx_size);
+
   if (xd->mb_to_bottom_edge < 0)
     max_blocks_high += xd->mb_to_bottom_edge >> 5;
   if (xd->mb_to_right_edge < 0)
@@ -169,15 +173,20 @@
      return;
 
   if (tx_size == mbmi->inter_tx_size[tx_idx]) {
-    vpx_write_bit(w, 0);
+    vpx_write(w, 0, cm->fc->txfm_partition_prob[ctx]);
+    txfm_partition_update(xd->above_txfm_context + (blk_col >> 1),
+                          xd->left_txfm_context + (blk_row >> 1), tx_size);
   } else {
     const BLOCK_SIZE bsize = txsize_to_bsize[tx_size];
     int bsl = b_width_log2_lookup[bsize];
     int i;
-    vpx_write_bit(w, 1);
+    vpx_write(w, 1, cm->fc->txfm_partition_prob[ctx]);
 
-    if (tx_size == TX_8X8)
+    if (tx_size == TX_8X8) {
+      txfm_partition_update(xd->above_txfm_context + (blk_col >> 1),
+                            xd->left_txfm_context + (blk_row >> 1), TX_4X4);
       return;
+    }
 
     assert(bsl > 0);
     --bsl;
@@ -188,6 +197,14 @@
     }
   }
 }
+
+static void update_txfm_partition_probs(VP10_COMMON *cm, vpx_writer *w,
+                                        FRAME_COUNTS *counts) {
+  int k;
+  for (k = 0; k < TXFM_PARTITION_CONTEXTS; ++k)
+    vp10_cond_prob_diff_update(w, &cm->fc->txfm_partition_prob[k],
+                               counts->txfm_partition[k]);
+}
 #endif
 
 static void write_selected_tx_size(const VP10_COMMON *cm,
@@ -498,8 +515,8 @@
                                 vpx_writer *w) {
   VP10_COMMON *const cm = &cpi->common;
   const nmv_context *nmvc = &cm->fc->nmvc;
-  const MACROBLOCK *const x = &cpi->td.mb;
-  const MACROBLOCKD *const xd = &x->e_mbd;
+  const MACROBLOCK *x = &cpi->td.mb;
+  const MACROBLOCKD *xd = &x->e_mbd;
   const struct segmentation *const seg = &cm->seg;
 #if CONFIG_MISC_FIXES
   const struct segmentation_probs *const segp = &cm->fc->seg;
@@ -750,6 +767,10 @@
   if (frame_is_intra_only(cm)) {
     write_mb_modes_kf(cm, xd, xd->mi, w);
   } else {
+#if CONFIG_VAR_TX
+    xd->above_txfm_context = cm->above_txfm_context + mi_col;
+    xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+#endif
     pack_inter_mode_mvs(cpi, m, w);
   }
 
@@ -897,6 +918,9 @@
   for (mi_row = tile->mi_row_start; mi_row < tile->mi_row_end;
        mi_row += MI_BLOCK_SIZE) {
     vp10_zero(xd->left_seg_context);
+#if CONFIG_VAR_TX
+    vp10_zero(xd->left_txfm_context_buffer);
+#endif
     for (mi_col = tile->mi_col_start; mi_col < tile->mi_col_end;
          mi_col += MI_BLOCK_SIZE)
       write_modes_sb(cpi, tile, w, tok, tok_end, mi_row, mi_col,
@@ -1385,6 +1409,10 @@
 
   memset(cm->above_seg_context, 0,
          sizeof(*cm->above_seg_context) * mi_cols_aligned_to_sb(cm->mi_cols));
+#if CONFIG_VAR_TX
+  memset(cm->above_txfm_context, 0,
+         sizeof(*cm->above_txfm_context) * mi_cols_aligned_to_sb(cm->mi_cols));
+#endif
 
   for (tile_row = 0; tile_row < tile_rows; tile_row++) {
     for (tile_col = 0; tile_col < tile_cols; tile_col++) {
@@ -1659,6 +1687,11 @@
   update_txfm_probs(cm, &header_bc, counts);
 #endif
   update_coef_probs(cpi, &header_bc);
+
+#if CONFIG_VAR_TX
+  update_txfm_partition_probs(cm, &header_bc, counts);
+#endif
+
   update_skip_probs(cm, &header_bc, counts);
 #if CONFIG_MISC_FIXES
   update_seg_probs(cpi, &header_bc);
diff --git a/vp10/encoder/encodeframe.c b/vp10/encoder/encodeframe.c
index ea09f6e..32f00f7 100644
--- a/vp10/encoder/encodeframe.c
+++ b/vp10/encoder/encodeframe.c
@@ -196,6 +196,11 @@
 
   set_mode_info_offsets(cpi, x, xd, mi_row, mi_col);
 
+#if CONFIG_VAR_TX
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+#endif
+
   mbmi = &xd->mi[0]->mbmi;
 
   // Set up destination pointers.
@@ -1302,6 +1307,9 @@
                             ENTROPY_CONTEXT a[16 * MAX_MB_PLANE],
                             ENTROPY_CONTEXT l[16 * MAX_MB_PLANE],
                             PARTITION_CONTEXT sa[8], PARTITION_CONTEXT sl[8],
+#if CONFIG_VAR_TX
+                            TXFM_CONTEXT ta[8], TXFM_CONTEXT tl[8],
+#endif
                             BLOCK_SIZE bsize) {
   MACROBLOCKD *const xd = &x->e_mbd;
   int p;
@@ -1326,12 +1334,21 @@
          sizeof(*xd->above_seg_context) * mi_width);
   memcpy(xd->left_seg_context + (mi_row & MI_MASK), sl,
          sizeof(xd->left_seg_context[0]) * mi_height);
+#if CONFIG_VAR_TX
+  memcpy(xd->above_txfm_context, ta,
+         sizeof(*xd->above_txfm_context) * mi_width);
+  memcpy(xd->left_txfm_context, tl,
+         sizeof(*xd->left_txfm_context) * mi_height);
+#endif
 }
 
 static void save_context(MACROBLOCK *const x, int mi_row, int mi_col,
                          ENTROPY_CONTEXT a[16 * MAX_MB_PLANE],
                          ENTROPY_CONTEXT l[16 * MAX_MB_PLANE],
                          PARTITION_CONTEXT sa[8], PARTITION_CONTEXT sl[8],
+#if CONFIG_VAR_TX
+                         TXFM_CONTEXT ta[8], TXFM_CONTEXT tl[8],
+#endif
                          BLOCK_SIZE bsize) {
   const MACROBLOCKD *const xd = &x->e_mbd;
   int p;
@@ -1358,6 +1375,12 @@
          sizeof(*xd->above_seg_context) * mi_width);
   memcpy(sl, xd->left_seg_context + (mi_row & MI_MASK),
          sizeof(xd->left_seg_context[0]) * mi_height);
+#if CONFIG_VAR_TX
+  memcpy(ta, xd->above_txfm_context,
+         sizeof(*xd->above_txfm_context) * mi_width);
+  memcpy(tl, xd->left_txfm_context,
+         sizeof(*xd->left_txfm_context) * mi_height);
+#endif
 }
 
 static void encode_b(VP10_COMP *cpi, const TileInfo *const tile,
@@ -1542,6 +1565,9 @@
   BLOCK_SIZE subsize;
   ENTROPY_CONTEXT l[16 * MAX_MB_PLANE], a[16 * MAX_MB_PLANE];
   PARTITION_CONTEXT sl[8], sa[8];
+#if CONFIG_VAR_TX
+  TXFM_CONTEXT tl[8], ta[8];
+#endif
   RD_COST last_part_rdc, none_rdc, chosen_rdc;
   BLOCK_SIZE sub_subsize = BLOCK_4X4;
   int splits_below = 0;
@@ -1562,8 +1588,16 @@
   partition = partition_lookup[bsl][bs_type];
   subsize = get_subsize(bsize, partition);
 
+#if CONFIG_VAR_TX
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+#endif
   pc_tree->partitioning = partition;
-  save_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+  save_context(x, mi_row, mi_col, a, l, sa, sl,
+#if CONFIG_VAR_TX
+               ta, tl,
+#endif
+               bsize);
 
   if (bsize == BLOCK_16X16 && cpi->oxcf.aq_mode) {
     set_offsets(cpi, tile_info, x, mi_row, mi_col, bsize);
@@ -1603,7 +1637,11 @@
                                  none_rdc.dist);
       }
 
-      restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+      restore_context(x, mi_row, mi_col, a, l, sa, sl,
+#if CONFIG_VAR_TX
+                      ta, tl,
+#endif
+                      bsize);
       mi_8x8[0]->mbmi.sb_type = bs_type;
       pc_tree->partitioning = partition;
     }
@@ -1714,7 +1752,11 @@
     BLOCK_SIZE split_subsize = get_subsize(bsize, PARTITION_SPLIT);
     chosen_rdc.rate = 0;
     chosen_rdc.dist = 0;
-    restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+    restore_context(x, mi_row, mi_col, a, l, sa, sl,
+#if CONFIG_VAR_TX
+                    ta, tl,
+#endif
+                    bsize);
     pc_tree->partitioning = PARTITION_SPLIT;
 
     // Split partition.
@@ -1724,17 +1766,28 @@
       RD_COST tmp_rdc;
       ENTROPY_CONTEXT l[16 * MAX_MB_PLANE], a[16 * MAX_MB_PLANE];
       PARTITION_CONTEXT sl[8], sa[8];
+#if CONFIG_VAR_TX
+      TXFM_CONTEXT tl[8], ta[8];
+#endif
 
       if ((mi_row + y_idx >= cm->mi_rows) || (mi_col + x_idx >= cm->mi_cols))
         continue;
 
-      save_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+      save_context(x, mi_row, mi_col, a, l, sa, sl,
+#if CONFIG_VAR_TX
+                   ta, tl,
+#endif
+                   bsize);
       pc_tree->split[i]->partitioning = PARTITION_NONE;
       rd_pick_sb_modes(cpi, tile_data, x,
                        mi_row + y_idx, mi_col + x_idx, &tmp_rdc,
                        split_subsize, &pc_tree->split[i]->none, INT64_MAX);
 
-      restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+      restore_context(x, mi_row, mi_col, a, l, sa, sl,
+#if CONFIG_VAR_TX
+                      ta, tl,
+#endif
+                      bsize);
 
       if (tmp_rdc.rate == INT_MAX || tmp_rdc.dist == INT64_MAX) {
         vp10_rd_cost_reset(&chosen_rdc);
@@ -1774,7 +1827,15 @@
     chosen_rdc = none_rdc;
   }
 
-  restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+#if CONFIG_VAR_TX
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+#endif
+  restore_context(x, mi_row, mi_col, a, l, sa, sl,
+#if CONFIG_VAR_TX
+                  ta, tl,
+#endif
+                  bsize);
 
   // We must have chosen a partitioning and encoding or we'll fail later on.
   // No other opportunities for success.
@@ -2046,6 +2107,9 @@
   const int mi_step = num_8x8_blocks_wide_lookup[bsize] / 2;
   ENTROPY_CONTEXT l[16 * MAX_MB_PLANE], a[16 * MAX_MB_PLANE];
   PARTITION_CONTEXT sl[8], sa[8];
+#if CONFIG_VAR_TX
+  TXFM_CONTEXT tl[8], ta[8];
+#endif
   TOKENEXTRA *tp_orig = *tp;
   PICK_MODE_CONTEXT *ctx = &pc_tree->none;
   int i, pl;
@@ -2111,7 +2175,13 @@
     partition_vert_allowed &= force_vert_split;
   }
 
+#if CONFIG_VAR_TX
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+  save_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
+#else
   save_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+#endif
 
 #if CONFIG_FP_MB_STATS
   if (cpi->use_fp_mb_stats) {
@@ -2257,7 +2327,13 @@
 #endif
       }
     }
+#if CONFIG_VAR_TX
+    xd->above_txfm_context = cm->above_txfm_context + mi_col;
+    xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+    restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
+#else
     restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+#endif
   }
 
   // store estimated motion vector
@@ -2322,7 +2398,13 @@
       if (cpi->sf.less_rectangular_check)
         do_rect &= !partition_none_allowed;
     }
+#if CONFIG_VAR_TX
+    xd->above_txfm_context = cm->above_txfm_context + mi_col;
+    xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+    restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
+#else
     restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+#endif
   }
 
   // PARTITION_HORZ
@@ -2371,7 +2453,13 @@
         pc_tree->partitioning = PARTITION_HORZ;
       }
     }
+#if CONFIG_VAR_TX
+    xd->above_txfm_context = cm->above_txfm_context + mi_col;
+    xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+    restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
+#else
     restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+#endif
   }
   // PARTITION_VERT
   if (partition_vert_allowed &&
@@ -2420,7 +2508,13 @@
         pc_tree->partitioning = PARTITION_VERT;
       }
     }
+#if CONFIG_VAR_TX
+    xd->above_txfm_context = cm->above_txfm_context + mi_col;
+    xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+    restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
+#else
     restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+#endif
   }
 
   // TODO(jbb): This code added so that we avoid static analysis
@@ -2461,7 +2555,10 @@
   // Initialize the left context for the new SB row
   memset(&xd->left_context, 0, sizeof(xd->left_context));
   memset(xd->left_seg_context, 0, sizeof(xd->left_seg_context));
-
+#if CONFIG_VAR_TX
+  memset(xd->left_txfm_context_buffer, 0,
+         sizeof(xd->left_txfm_context_buffer));
+#endif
   // Code each SB in the row
   for (mi_col = tile_info->mi_col_start; mi_col < tile_info->mi_col_end;
        mi_col += MI_BLOCK_SIZE) {
@@ -2549,6 +2646,10 @@
          2 * aligned_mi_cols * MAX_MB_PLANE);
   memset(xd->above_seg_context, 0,
          sizeof(*xd->above_seg_context) * aligned_mi_cols);
+#if CONFIG_VAR_TX
+  memset(cm->above_txfm_context, 0,
+         sizeof(*xd->above_txfm_context) * aligned_mi_cols);
+#endif
 }
 
 static int check_dual_ref_flags(VP10_COMP *cpi) {
@@ -2947,6 +3048,140 @@
   ++counts->uv_mode[y_mode][uv_mode];
 }
 
+#if CONFIG_VAR_TX
+static void update_txfm_count(MACROBLOCKD *xd,
+                              FRAME_COUNTS *counts,
+                              TX_SIZE tx_size, int blk_row, int blk_col) {
+  MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
+  int tx_idx = (blk_row >> 1) * 8 + (blk_col >> 1);
+  int max_blocks_high = num_4x4_blocks_high_lookup[mbmi->sb_type];
+  int max_blocks_wide = num_4x4_blocks_wide_lookup[mbmi->sb_type];
+  int ctx = txfm_partition_context(xd->above_txfm_context + (blk_col >> 1),
+                                   xd->left_txfm_context + (blk_row >> 1),
+                                   tx_size);
+  TX_SIZE plane_tx_size = mbmi->inter_tx_size[tx_idx];
+
+  if (xd->mb_to_bottom_edge < 0)
+    max_blocks_high += xd->mb_to_bottom_edge >> 5;
+  if (xd->mb_to_right_edge < 0)
+    max_blocks_wide += xd->mb_to_right_edge >> 5;
+
+  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide)
+    return;
+
+  if (tx_size == plane_tx_size) {
+    ++counts->txfm_partition[ctx][0];
+    mbmi->tx_size = tx_size;
+    txfm_partition_update(xd->above_txfm_context + (blk_col >> 1),
+                          xd->left_txfm_context + (blk_row >> 1), tx_size);
+  } else {
+    BLOCK_SIZE bsize = txsize_to_bsize[tx_size];
+    int bh = num_4x4_blocks_high_lookup[bsize];
+    int i;
+    ++counts->txfm_partition[ctx][1];
+
+    if (tx_size == TX_8X8) {
+      mbmi->inter_tx_size[tx_idx] = TX_4X4;
+      mbmi->tx_size = TX_4X4;
+      txfm_partition_update(xd->above_txfm_context + (blk_col >> 1),
+                            xd->left_txfm_context + (blk_row >> 1), TX_4X4);
+      return;
+    }
+
+    for (i = 0; i < 4; ++i) {
+      int offsetr = (i >> 1) * bh / 2;
+      int offsetc = (i & 0x01) * bh / 2;
+      update_txfm_count(xd, counts, tx_size - 1,
+                        blk_row + offsetr, blk_col + offsetc);
+    }
+  }
+}
+
+static void tx_partition_count_update(VP10_COMMON *cm,
+                                      MACROBLOCKD *xd,
+                                      BLOCK_SIZE plane_bsize,
+                                      int mi_row, int mi_col,
+                                      FRAME_COUNTS *td_counts) {
+  const int mi_width = num_4x4_blocks_wide_lookup[plane_bsize];
+  const int mi_height = num_4x4_blocks_high_lookup[plane_bsize];
+  TX_SIZE max_tx_size = max_txsize_lookup[plane_bsize];
+  BLOCK_SIZE txb_size = txsize_to_bsize[max_tx_size];
+  int bh = num_4x4_blocks_wide_lookup[txb_size];
+  int idx, idy;
+
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+
+  for (idy = 0; idy < mi_height; idy += bh)
+    for (idx = 0; idx < mi_width; idx += bh)
+      update_txfm_count(xd, td_counts, max_tx_size, idy, idx);
+}
+
+static void set_txfm_context(MACROBLOCKD *xd, TX_SIZE tx_size,
+                             int blk_row, int blk_col) {
+  MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
+  int tx_idx = (blk_row >> 1) * 8 + (blk_col >> 1);
+  int max_blocks_high = num_4x4_blocks_high_lookup[mbmi->sb_type];
+  int max_blocks_wide = num_4x4_blocks_wide_lookup[mbmi->sb_type];
+  TX_SIZE plane_tx_size = mbmi->inter_tx_size[tx_idx];
+
+  if (xd->mb_to_bottom_edge < 0)
+    max_blocks_high += xd->mb_to_bottom_edge >> 5;
+  if (xd->mb_to_right_edge < 0)
+    max_blocks_wide += xd->mb_to_right_edge >> 5;
+
+  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide)
+    return;
+
+  if (tx_size == plane_tx_size) {
+    mbmi->tx_size = tx_size;
+    txfm_partition_update(xd->above_txfm_context + (blk_col >> 1),
+                          xd->left_txfm_context + (blk_row >> 1), tx_size);
+
+  } else {
+    BLOCK_SIZE bsize = txsize_to_bsize[tx_size];
+    int bsl = b_width_log2_lookup[bsize];
+    int i;
+
+    if (tx_size == TX_8X8) {
+      mbmi->inter_tx_size[tx_idx] = TX_4X4;
+      mbmi->tx_size = TX_4X4;
+      txfm_partition_update(xd->above_txfm_context + (blk_col >> 1),
+                            xd->left_txfm_context + (blk_row >> 1), TX_4X4);
+      return;
+    }
+
+    assert(bsl > 0);
+    --bsl;
+    for (i = 0; i < 4; ++i) {
+      int offsetr = (i >> 1) << bsl;
+      int offsetc = (i & 0x01) << bsl;
+      set_txfm_context(xd, tx_size - 1,
+                       blk_row + offsetr, blk_col + offsetc);
+    }
+  }
+}
+
+static void tx_partition_set_contexts(VP10_COMMON *cm,
+                                      MACROBLOCKD *xd,
+                                      BLOCK_SIZE plane_bsize,
+                                      int mi_row, int mi_col) {
+  const int mi_width = num_4x4_blocks_wide_lookup[plane_bsize];
+  const int mi_height = num_4x4_blocks_high_lookup[plane_bsize];
+  TX_SIZE max_tx_size = max_txsize_lookup[plane_bsize];
+  BLOCK_SIZE txb_size = txsize_to_bsize[max_tx_size];
+  int bh = num_4x4_blocks_wide_lookup[txb_size];
+  int idx, idy;
+
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+
+  for (idy = 0; idy < mi_height; idy += bh)
+    for (idx = 0; idx < mi_width; idx += bh)
+      set_txfm_context(xd, max_tx_size, idy, idx);
+}
+#endif
+
 static void encode_superblock(VP10_COMP *cpi, ThreadData *td,
                               TOKENEXTRA **t, int output_enabled,
                               int mi_row, int mi_col, BLOCK_SIZE bsize,
@@ -3027,12 +3262,15 @@
         !(is_inter_block(mbmi) && (mbmi->skip || seg_skip))) {
 #if CONFIG_VAR_TX
       int tx_size_ctx = get_tx_size_context(xd);
-      if (is_inter_block(mbmi))
+      if (is_inter_block(mbmi)) {
+        tx_partition_count_update(cm, xd, bsize, mi_row, mi_col,
+                                  td->counts);
         inter_block_tx_count_update(cm, xd, mbmi, bsize,
                                     tx_size_ctx, &td->counts->tx);
-      else
+      } else {
         ++get_tx_counts(max_txsize_lookup[bsize], get_tx_size_context(xd),
                         &td->counts->tx)[mbmi->tx_size];
+      }
 #else
       ++get_tx_counts(max_txsize_lookup[bsize], get_tx_size_context(xd),
                       &td->counts->tx)[mbmi->tx_size];
@@ -3072,4 +3310,11 @@
     }
 #endif  // CONFIG_EXT_TX
   }
+
+#if CONFIG_VAR_TX
+  if (cm->tx_mode == TX_MODE_SELECT && mbmi->sb_type >= BLOCK_8X8 &&
+      is_inter_block(mbmi) && !(mbmi->skip || seg_skip) &&
+      !output_enabled)
+    tx_partition_set_contexts(cm, xd, bsize, mi_row, mi_col);
+#endif
 }
diff --git a/vp10/encoder/rdopt.c b/vp10/encoder/rdopt.c
index fc7a107..4a4362e 100644
--- a/vp10/encoder/rdopt.c
+++ b/vp10/encoder/rdopt.c
@@ -1708,6 +1708,7 @@
                             int blk_row, int blk_col, int plane, int block,
                             TX_SIZE tx_size, BLOCK_SIZE plane_bsize,
                             ENTROPY_CONTEXT *ta, ENTROPY_CONTEXT *tl,
+                            TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
                             int *rate, int64_t *dist,
                             int64_t *bsse, int *skip,
                             int64_t ref_best_rd, int *is_cost_valid) {
@@ -1724,9 +1725,12 @@
   ENTROPY_CONTEXT *ptl = tl + blk_row;
   ENTROPY_CONTEXT stxa = 0, stxl = 0;
   int coeff_ctx, i;
+  int ctx = txfm_partition_context(tx_above + (blk_col >> 1),
+                                   tx_left + (blk_row >> 1), tx_size);
+
   int64_t sum_dist = 0, sum_bsse = 0;
   int64_t sum_rd = INT64_MAX;
-  int sum_rate = vp10_cost_bit(128, 1);
+  int sum_rate = vp10_cost_bit(cpi->common.fc->txfm_partition_prob[ctx], 1);
   int all_skip = 1;
   int tmp_eob = 0;
 
@@ -1776,7 +1780,7 @@
     tx_block_rd_b(cpi, x, tx_size, blk_row, blk_col, plane, block,
                   plane_bsize, coeff_ctx, rate, dist, bsse, skip);
     if (tx_size > TX_4X4)
-      *rate += vp10_cost_bit(128, 0);
+      *rate += vp10_cost_bit(cpi->common.fc->txfm_partition_prob[ctx], 0);
     this_rd = RDCOST(x->rdmult, x->rddiv, *rate, *dist);
     tmp_eob = p->eobs[block];
   }
@@ -1799,7 +1803,8 @@
       int offsetc = (i & 0x01) << bsl;
       select_tx_block(cpi, x, blk_row + offsetr, blk_col + offsetc,
                       plane, block + i * sub_step, tx_size - 1,
-                      plane_bsize, ta, tl, &this_rate, &this_dist,
+                      plane_bsize, ta, tl, tx_above, tx_left,
+                      &this_rate, &this_dist,
                       &this_bsse, &this_skip,
                       ref_best_rd - tmp_rd, &this_cost_valid);
       sum_rate += this_rate;
@@ -1818,6 +1823,8 @@
     int idx, idy;
     for (i = 0; i < (1 << tx_size); ++i)
       pta[i] = ptl[i] = !(tmp_eob == 0);
+    txfm_partition_update(tx_above + (blk_col >> 1),
+                          tx_left + (blk_row >> 1), tx_size);
     mbmi->inter_tx_size[tx_idx] = tx_size;
 
     for (idy = 0; idy < (1 << tx_size) / 2; ++idy)
@@ -1867,17 +1874,23 @@
     int block = 0;
     int step = 1 << (max_txsize_lookup[plane_bsize] * 2);
     ENTROPY_CONTEXT ctxa[16], ctxl[16];
+    TXFM_CONTEXT tx_above[8], tx_left[8];
 
     int pnrate = 0, pnskip = 1;
     int64_t pndist = 0, pnsse = 0;
 
     vp10_get_entropy_contexts(bsize, TX_4X4, pd, ctxa, ctxl);
+    memcpy(tx_above, xd->above_txfm_context,
+           sizeof(TXFM_CONTEXT) * (mi_width >> 1));
+    memcpy(tx_left, xd->left_txfm_context,
+           sizeof(TXFM_CONTEXT) * (mi_height >> 1));
 
     for (idy = 0; idy < mi_height; idy += bh) {
       for (idx = 0; idx < mi_width; idx += bh) {
         select_tx_block(cpi, x, idy, idx, 0, block,
                         max_txsize_lookup[plane_bsize], plane_bsize,
-                        ctxa, ctxl, &pnrate, &pndist, &pnsse, &pnskip,
+                        ctxa, ctxl, tx_above, tx_left,
+                        &pnrate, &pndist, &pnsse, &pnskip,
                         ref_best_rd - this_rd, &is_cost_valid);
         *rate += pnrate;
         *distortion += pndist;