rect_tx_ext: make quarter tx sizes work with var_tx

Change-Id: Ie2c34490dc50cb242bcd701308e6b55243883b15
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 2331307..9241572 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -355,7 +355,12 @@
     return;
   }
 
+#if CONFIG_RECT_TX_EXT
+  if (tx_size == mbmi->inter_tx_size[tx_row][tx_col] ||
+      mbmi->tx_size == quarter_txsize_lookup[mbmi->sb_type]) {
+#else
   if (tx_size == mbmi->inter_tx_size[tx_row][tx_col]) {
+#endif
 #if CONFIG_NEW_MULTISYMBOL
     aom_write_symbol(w, 0, ec_ctx->txfm_partition_cdf[ctx], 2);
 #else
@@ -364,6 +369,7 @@
 
     txfm_partition_update(xd->above_txfm_context + blk_col,
                           xd->left_txfm_context + blk_row, tx_size, tx_size);
+    // TODO(yuec): set correct txfm partition update for qttx
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
     const int bsl = tx_size_wide_unit[sub_txs];
@@ -427,11 +433,11 @@
 
     aom_write_symbol(w, depth, ec_ctx->tx_size_cdf[tx_size_cat][tx_size_ctx],
                      tx_size_cat + 2);
-#if CONFIG_EXT_TX && CONFIG_RECT_TX && CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT && (CONFIG_EXT_TX || CONFIG_VAR_TX)
     if (is_quarter_tx_allowed(xd, mbmi, is_inter) && tx_size != coded_tx_size)
       aom_write(w, tx_size == quarter_txsize_lookup[bsize],
                 cm->fc->quarter_tx_size_prob);
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX && CONFIG_RECT_TX_EXT
+#endif
   }
 }
 
@@ -963,15 +969,30 @@
     token_stats->cost += tmp_token_stats.cost;
 #endif
   } else {
+#if CONFIG_RECT_TX_EXT
+    int is_qttx = plane_tx_size == quarter_txsize_lookup[plane_bsize];
+    const TX_SIZE sub_txs = is_qttx ? plane_tx_size : sub_tx_size_map[tx_size];
+#else
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
+#endif
     const int bsl = tx_size_wide_unit[sub_txs];
     int i;
 
     assert(bsl > 0);
 
     for (i = 0; i < 4; ++i) {
+#if CONFIG_RECT_TX_EXT
+      int is_wide_tx = tx_size_wide_unit[sub_txs] > tx_size_high_unit[sub_txs];
+      const int offsetr =
+          is_qttx ? (is_wide_tx ? i * tx_size_high_unit[sub_txs] : 0)
+                  : blk_row + (i >> 1) * bsl;
+      const int offsetc =
+          is_qttx ? (is_wide_tx ? 0 : i * tx_size_wide_unit[sub_txs])
+                  : blk_col + (i & 0x01) * bsl;
+#else
       const int offsetr = blk_row + (i >> 1) * bsl;
       const int offsetc = blk_col + (i & 0x01) * bsl;
+#endif
       const int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
 
       if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
@@ -1748,6 +1769,15 @@
         for (idx = 0; idx < width; idx += bw)
           write_tx_size_vartx(cm, xd, mbmi, max_tx_size, height != width, idy,
                               idx, w);
+#if CONFIG_RECT_TX_EXT
+      if (is_quarter_tx_allowed(xd, mbmi, is_inter_block(mbmi)) &&
+          quarter_txsize_lookup[bsize] != max_tx_size &&
+          (mbmi->tx_size == quarter_txsize_lookup[bsize] ||
+           mbmi->tx_size == max_tx_size)) {
+        aom_write(w, mbmi->tx_size != max_tx_size,
+                  cm->fc->quarter_tx_size_prob);
+      }
+#endif
     } else {
       set_txfm_ctxs(mbmi->tx_size, xd->n8_w, xd->n8_h, skip, xd);
       write_selected_tx_size(cm, xd, w);
@@ -4551,11 +4581,11 @@
 #if CONFIG_LOOP_RESTORATION
   encode_restoration(cm, header_bc);
 #endif  // CONFIG_LOOP_RESTORATION
-#if CONFIG_EXT_TX && CONFIG_RECT_TX && CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT && (CONFIG_EXT_TX || CONFIG_VAR_TX)
   if (cm->tx_mode == TX_MODE_SELECT)
     av1_cond_prob_diff_update(header_bc, &cm->fc->quarter_tx_size_prob,
                               cm->counts.quarter_tx_size, probwt);
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX && CONFIG_RECT_TX_EXT
+#endif
 #if CONFIG_LV_MAP
   av1_write_txb_probs(cpi, header_bc);
 #endif  // CONFIG_LV_MAP