Make rectangular txfm in EXT_TX work with VAR_TX

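Adapt the rectangular transform (RECT_TX) experiment to the syntax,
tokenization, and loop-filter framework of VAR_TX.

Inter blocks whose transform size is rectangular (tx_size >= TX_SIZES)
bypass the VAR_TX transform-partition routines (decode_reconstruct_tx,
encode_block_inter, vp10_tokenize_sb_inter) and are handled by the
fixed-size per-block routines instead. The above and left transform
contexts are set from txsize_horz_map and txsize_vert_map
respectively, so they always hold square transform sizes.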

Change-Id: Idcb005ecf5b3712de3e1cccb0d811ca16d87af24
diff --git a/vp10/common/common_data.h b/vp10/common/common_data.h
index 2f42c37..65e99e1 100644
--- a/vp10/common/common_data.h
+++ b/vp10/common/common_data.h
@@ -368,6 +368,36 @@
 #endif  // CONFIG_EXT_TX
 /* clang-format on */
 
+static const TX_SIZE txsize_horz_map[TX_SIZES_ALL] = {
+  TX_4X4,    // TX_4X4
+  TX_8X8,    // TX_8X8
+  TX_16X16,  // TX_16X16
+  TX_32X32,  // TX_32X32
+#if CONFIG_EXT_TX
+  TX_4X4,    // TX_4X8
+  TX_8X8,    // TX_8X4
+  TX_8X8,    // TX_8X16
+  TX_16X16,  // TX_16X8
+  TX_16X16,  // TX_16X32
+  TX_32X32   // TX_32X16
+#endif       // CONFIG_EXT_TX
+};
+
+static const TX_SIZE txsize_vert_map[TX_SIZES_ALL] = {
+  TX_4X4,    // TX_4X4
+  TX_8X8,    // TX_8X8
+  TX_16X16,  // TX_16X16
+  TX_32X32,  // TX_32X32
+#if CONFIG_EXT_TX
+  TX_8X8,    // TX_4X8
+  TX_4X4,    // TX_8X4
+  TX_16X16,  // TX_8X16
+  TX_8X8,    // TX_16X8
+  TX_32X32,  // TX_16X32
+  TX_16X16   // TX_32X16
+#endif       // CONFIG_EXT_TX
+};
+
 static const BLOCK_SIZE txsize_to_bsize[TX_SIZES_ALL] = {
   BLOCK_4X4,    // TX_4X4
   BLOCK_8X8,    // TX_8X8
diff --git a/vp10/common/loopfilter.c b/vp10/common/loopfilter.c
index 1c50c4e..eaa0e7e 100644
--- a/vp10/common/loopfilter.c
+++ b/vp10/common/loopfilter.c
@@ -1252,6 +1252,16 @@
                                       sb_type, ss_x, ss_y)
                 : mbmi->inter_tx_size[blk_row][blk_col];
 
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+      tx_size_r =
+          VPXMIN(txsize_horz_map[tx_size], cm->above_txfm_context[mi_col + c]);
+      tx_size_c = VPXMIN(txsize_vert_map[tx_size],
+                         cm->left_txfm_context[(mi_row + r) & MAX_MIB_MASK]);
+
+      cm->above_txfm_context[mi_col + c] = txsize_horz_map[tx_size];
+      cm->left_txfm_context[(mi_row + r) & MAX_MIB_MASK] =
+          txsize_vert_map[tx_size];
+#else
       tx_size_r = VPXMIN(tx_size, cm->above_txfm_context[mi_col + c]);
       tx_size_c =
           VPXMIN(tx_size, cm->left_txfm_context[(mi_row + r) & MAX_MIB_MASK]);
@@ -1259,6 +1269,7 @@
       cm->above_txfm_context[mi_col + c] = tx_size;
       cm->left_txfm_context[(mi_row + r) & MAX_MIB_MASK] = tx_size;
 #endif
+#endif
 
       // Build masks based on the transform size of each block
       // handle vertical mask
diff --git a/vp10/common/onyxc_int.h b/vp10/common/onyxc_int.h
index dfa04b5..fcd328f 100644
--- a/vp10/common/onyxc_int.h
+++ b/vp10/common/onyxc_int.h
@@ -661,6 +661,12 @@
   for (i = 0; i < len; ++i) txfm_ctx[i] = tx_size;
 }
 
+static INLINE void set_txfm_ctxs(TX_SIZE tx_size, int n8_w, int n8_h,
+                                 const MACROBLOCKD *xd) {
+  set_txfm_ctx(xd->above_txfm_context, txsize_horz_map[tx_size], n8_w);
+  set_txfm_ctx(xd->left_txfm_context, txsize_vert_map[tx_size], n8_h);
+}
+
 static INLINE void txfm_partition_update(TXFM_CONTEXT *above_ctx,
                                          TXFM_CONTEXT *left_ctx,
                                          TX_SIZE tx_size) {
diff --git a/vp10/decoder/decodeframe.c b/vp10/decoder/decodeframe.c
index 8288dd8..706954b 100644
--- a/vp10/decoder/decodeframe.c
+++ b/vp10/decoder/decodeframe.c
@@ -324,11 +324,7 @@
 
   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
 
-  if (tx_size == plane_tx_size
-#if CONFIG_EXT_TX && CONFIG_RECT_TX
-      || plane_tx_size >= TX_SIZES
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
-      ) {
+  if (tx_size == plane_tx_size) {
     PLANE_TYPE plane_type = (plane == 0) ? PLANE_TYPE_Y : PLANE_TYPE_UV;
     TX_TYPE tx_type = get_tx_type(plane_type, xd, block, plane_tx_size);
     const scan_order *sc = get_scan(plane_tx_size, tx_type, 1);
@@ -361,7 +357,7 @@
 }
 #endif  // CONFIG_VAR_TX
 
-#if !CONFIG_VAR_TX || CONFIG_SUPERTX
+#if !CONFIG_VAR_TX || CONFIG_SUPERTX || (CONFIG_EXT_TX && CONFIG_RECT_TX)
 static int reconstruct_inter_block(MACROBLOCKD *const xd,
 #if CONFIG_ANS
                                    struct AnsDecoder *const r,
@@ -533,8 +529,7 @@
   xd->above_txfm_context = cm->above_txfm_context + mi_col;
   xd->left_txfm_context =
       xd->left_txfm_context_buffer + (mi_row & MAX_MIB_MASK);
-  set_txfm_ctx(xd->left_txfm_context, xd->mi[0]->mbmi.tx_size, bh);
-  set_txfm_ctx(xd->above_txfm_context, xd->mi[0]->mbmi.tx_size, bw);
+  set_txfm_ctxs(xd->mi[0]->mbmi.tx_size, bw, bh, xd);
 #endif
 }
 
@@ -1324,24 +1319,44 @@
         // TODO(jingning): This can be simplified for decoder performance.
         const BLOCK_SIZE plane_bsize =
             get_plane_block_size(VPXMAX(bsize, BLOCK_8X8), pd);
-#if CONFIG_EXT_TX && CONFIG_RECT_TX
-        const TX_SIZE max_tx_size = plane ? max_txsize_lookup[plane_bsize]
-                                          : max_txsize_rect_lookup[plane_bsize];
-#else
         const TX_SIZE max_tx_size = max_txsize_lookup[plane_bsize];
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
         int bw = num_4x4_blocks_wide_txsize_lookup[max_tx_size];
         int bh = num_4x4_blocks_high_txsize_lookup[max_tx_size];
         const int step = num_4x4_blocks_txsize_lookup[max_tx_size];
         int block = 0;
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+        const TX_SIZE tx_size =
+            plane ? dec_get_uv_tx_size(mbmi, pd->n4_wl, pd->n4_hl)
+                  : mbmi->tx_size;
 
-        for (row = 0; row < num_4x4_h; row += bh) {
-          for (col = 0; col < num_4x4_w; col += bw) {
-            decode_reconstruct_tx(xd, r, mbmi, plane, plane_bsize, block, row,
-                                  col, max_tx_size, &eobtotal);
-            block += step;
+        if (tx_size >= TX_SIZES) {  // rect txsize is used
+          const int stepr = num_4x4_blocks_high_txsize_lookup[tx_size];
+          const int stepc = num_4x4_blocks_wide_txsize_lookup[tx_size];
+          const int max_blocks_wide =
+              num_4x4_w +
+              (xd->mb_to_right_edge >= 0 ? 0 : xd->mb_to_right_edge >>
+                                                   (5 + pd->subsampling_x));
+          const int max_blocks_high =
+              num_4x4_h +
+              (xd->mb_to_bottom_edge >= 0 ? 0 : xd->mb_to_bottom_edge >>
+                                                    (5 + pd->subsampling_y));
+
+          for (row = 0; row < max_blocks_high; row += stepr)
+            for (col = 0; col < max_blocks_wide; col += stepc)
+              eobtotal += reconstruct_inter_block(xd, r, mbmi->segment_id,
+                                                  plane, row, col, tx_size);
+        } else {
+#endif
+          for (row = 0; row < num_4x4_h; row += bh) {
+            for (col = 0; col < num_4x4_w; col += bw) {
+              decode_reconstruct_tx(xd, r, mbmi, plane, plane_bsize, block, row,
+                                    col, max_tx_size, &eobtotal);
+              block += step;
+            }
           }
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
         }
+#endif
 #else
         const TX_SIZE tx_size =
             plane ? dec_get_uv_tx_size(mbmi, pd->n4_wl, pd->n4_hl)
diff --git a/vp10/decoder/decodemv.c b/vp10/decoder/decodemv.c
index 2adb482..d8ce8fd 100644
--- a/vp10/decoder/decodemv.c
+++ b/vp10/decoder/decodemv.c
@@ -315,12 +315,12 @@
       return tx_size;
     }
   } else {
-#if CONFIG_EXT_TX && CONFIG_RECT_TX && !CONFIG_VAR_TX
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
     assert(IMPLIES(tx_mode == ONLY_4X4, bsize == BLOCK_4X4));
     return max_txsize_rect_lookup[bsize];
 #else
     return TX_4X4;
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX && !CONFIG_VAR_TX
+#endif
   }
 }
 
@@ -1642,8 +1642,7 @@
             mbmi->inter_tx_size[idy >> 1][idx >> 1] = mbmi->tx_size;
       }
 
-      set_txfm_ctx(xd->left_txfm_context, mbmi->tx_size, xd->n8_h);
-      set_txfm_ctx(xd->above_txfm_context, mbmi->tx_size, xd->n8_w);
+      set_txfm_ctxs(mbmi->tx_size, xd->n8_w, xd->n8_h, xd);
     }
 #else
   if (inter_block)
diff --git a/vp10/encoder/bitstream.c b/vp10/encoder/bitstream.c
index 60e54e4..e07b4bc 100644
--- a/vp10/encoder/bitstream.c
+++ b/vp10/encoder/bitstream.c
@@ -1118,14 +1118,11 @@
         for (idx = 0; idx < width; idx += bs)
           write_tx_size_vartx(cm, xd, mbmi, max_tx_size, idy, idx, w);
     } else {
-      set_txfm_ctx(xd->left_txfm_context, mbmi->tx_size, xd->n8_h);
-      set_txfm_ctx(xd->above_txfm_context, mbmi->tx_size, xd->n8_w);
-
+      set_txfm_ctxs(mbmi->tx_size, xd->n8_w, xd->n8_h, xd);
       write_selected_tx_size(cm, xd, w);
     }
   } else {
-    set_txfm_ctx(xd->left_txfm_context, mbmi->tx_size, xd->n8_h);
-    set_txfm_ctx(xd->above_txfm_context, mbmi->tx_size, xd->n8_w);
+    set_txfm_ctxs(mbmi->tx_size, xd->n8_w, xd->n8_h, xd);
 #else
     write_selected_tx_size(cm, xd, w);
 #endif
@@ -1640,8 +1637,14 @@
       const int num_4x4_w = num_4x4_blocks_wide_lookup[plane_bsize];
       const int num_4x4_h = num_4x4_blocks_high_lookup[plane_bsize];
       int row, col;
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+      TX_SIZE tx_size =
+          plane ? get_uv_tx_size(mbmi, &xd->plane[plane]) : mbmi->tx_size;
 
+      if (is_inter_block(mbmi) && tx_size < TX_SIZES) {
+#else
       if (is_inter_block(mbmi)) {
+#endif
         const TX_SIZE max_tx_size = max_txsize_lookup[plane_bsize];
         const BLOCK_SIZE txb_size = txsize_to_bsize[max_tx_size];
         int bw = num_4x4_blocks_wide_lookup[txb_size];
@@ -1659,8 +1662,9 @@
                            : m->mbmi.tx_size;
         BLOCK_SIZE txb_size = txsize_to_bsize[tx];
         int bw = num_4x4_blocks_wide_lookup[txb_size];
+        int bh = num_4x4_blocks_high_lookup[txb_size];
 
-        for (row = 0; row < num_4x4_h; row += bw)
+        for (row = 0; row < num_4x4_h; row += bh)
           for (col = 0; col < num_4x4_w; col += bw)
             pack_mb_tokens(w, tok, tok_end, cm->bit_depth, tx);
       }
diff --git a/vp10/encoder/encodeframe.c b/vp10/encoder/encodeframe.c
index 0efd6fb..c32a7d5 100644
--- a/vp10/encoder/encodeframe.c
+++ b/vp10/encoder/encodeframe.c
@@ -2229,8 +2229,7 @@
         update_partition_context(xd, mi_row, mi_col, subsize, bsize);
 #endif
 #if CONFIG_VAR_TX
-      set_txfm_ctx(xd->left_txfm_context, supertx_size, xd->n8_h);
-      set_txfm_ctx(xd->above_txfm_context, supertx_size, mi_height);
+      set_txfm_ctxs(supertx_size, mi_width, mi_height, xd);
 #endif  // CONFIG_VAR_TX
       return;
     } else {
@@ -5027,8 +5026,13 @@
 
     vp10_encode_sb(x, VPXMAX(bsize, BLOCK_8X8));
 #if CONFIG_VAR_TX
-    vp10_tokenize_sb_inter(cpi, td, t, !output_enabled, mi_row, mi_col,
-                           VPXMAX(bsize, BLOCK_8X8));
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+    if (mbmi->tx_size >= TX_SIZES)
+      vp10_tokenize_sb(cpi, td, t, !output_enabled, VPXMAX(bsize, BLOCK_8X8));
+    else
+#endif
+      vp10_tokenize_sb_inter(cpi, td, t, !output_enabled, mi_row, mi_col,
+                             VPXMAX(bsize, BLOCK_8X8));
 #else
     vp10_tokenize_sb(cpi, td, t, !output_enabled, VPXMAX(bsize, BLOCK_8X8));
 #endif
@@ -5108,13 +5112,22 @@
     TX_SIZE tx_size;
     // The new intra coding scheme requires no change of transform size
     if (is_inter_block(mbmi))
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+    {
       tx_size = VPXMIN(tx_mode_to_biggest_tx_size[cm->tx_mode],
                        max_txsize_lookup[bsize]);
+      if (txsize_sqr_map[max_txsize_rect_lookup[bsize]] <= tx_size)
+        tx_size = max_txsize_rect_lookup[bsize];
+      if (xd->lossless[mbmi->segment_id]) tx_size = TX_4X4;
+    }
+#else
+      tx_size = VPXMIN(tx_mode_to_biggest_tx_size[cm->tx_mode],
+                       max_txsize_lookup[bsize]);
+#endif
     else
       tx_size = (bsize >= BLOCK_8X8) ? mbmi->tx_size : TX_4X4;
     mbmi->tx_size = tx_size;
-    set_txfm_ctx(xd->left_txfm_context, tx_size, xd->n8_h);
-    set_txfm_ctx(xd->above_txfm_context, tx_size, xd->n8_w);
+    set_txfm_ctxs(tx_size, xd->n8_w, xd->n8_h, xd);
   }
 #endif
 }
diff --git a/vp10/encoder/encodemb.c b/vp10/encoder/encodemb.c
index 1e825e3..a6a4f5d 100644
--- a/vp10/encoder/encodemb.c
+++ b/vp10/encoder/encodemb.c
@@ -980,6 +980,9 @@
     int idx, idy;
     int block = 0;
     int step = num_4x4_blocks_txsize_lookup[max_tx_size];
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+    const TX_SIZE tx_size = plane ? get_uv_tx_size(mbmi, pd) : mbmi->tx_size;
+#endif
     vp10_get_entropy_contexts(bsize, TX_4X4, pd, ctx.ta[plane], ctx.tl[plane]);
 #else
     const struct macroblockd_plane *const pd = &xd->plane[plane];
@@ -991,13 +994,22 @@
     arg.tl = ctx.tl[plane];
 
 #if CONFIG_VAR_TX
-    for (idy = 0; idy < mi_height; idy += bh) {
-      for (idx = 0; idx < mi_width; idx += bh) {
-        encode_block_inter(plane, block, idy, idx, plane_bsize, max_tx_size,
-                           &arg);
-        block += step;
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+    if (tx_size >= TX_SIZES) {
+      vp10_foreach_transformed_block_in_plane(xd, bsize, plane, encode_block,
+                                              &arg);
+    } else {
+#endif
+      for (idy = 0; idy < mi_height; idy += bh) {
+        for (idx = 0; idx < mi_width; idx += bh) {
+          encode_block_inter(plane, block, idy, idx, plane_bsize, max_tx_size,
+                             &arg);
+          block += step;
+        }
       }
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
     }
+#endif
 #else
     vp10_foreach_transformed_block_in_plane(xd, bsize, plane, encode_block,
                                             &arg);
diff --git a/vp10/encoder/firstpass.c b/vp10/encoder/firstpass.c
index 7fe38a1..bd7e297 100644
--- a/vp10/encoder/firstpass.c
+++ b/vp10/encoder/firstpass.c
@@ -2584,7 +2584,7 @@
 
         cpi->rc.is_bwd_ref_frame = 1;
         cpi->bwd_fb_idx = cpi->alt_fb_idx;
-        cpi->alt_fb_idx = cpi->arf_map[0];;
+        cpi->alt_fb_idx = cpi->arf_map[0];
         cpi->arf_map[0] = tmp;
       } else {
         cpi->rc.is_bwd_ref_frame = 0;
diff --git a/vp10/encoder/rdopt.c b/vp10/encoder/rdopt.c
index b060878..9ecd2de 100644
--- a/vp10/encoder/rdopt.c
+++ b/vp10/encoder/rdopt.c
@@ -4220,13 +4220,13 @@
   const int num_4x4_w = num_4x4_blocks_wide_txsize_lookup[tx_size];
   const int num_4x4_h = num_4x4_blocks_high_txsize_lookup[tx_size];
 
-#if CONFIG_EXT_TX && CONFIG_RECT_TX && !CONFIG_VAR_TX
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
   assert(IMPLIES(xd->lossless[mi->mbmi.segment_id], tx_size == TX_4X4));
   assert(IMPLIES(!xd->lossless[mi->mbmi.segment_id],
                  tx_size == max_txsize_rect_lookup[mi->mbmi.sb_type]));
 #else
   assert(tx_size == TX_4X4);
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX && !CONFIG_VAR_TX
+#endif
   assert(tx_type == DCT_DCT);
 
   vp10_build_inter_predictor_sub8x8(xd, 0, i, ir, ic, mi_row, mi_col);
@@ -4746,12 +4746,12 @@
   const int has_second_rf = has_second_ref(mbmi);
   const int inter_mode_mask = cpi->sf.inter_mode_mask[bsize];
   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
-#if CONFIG_EXT_TX && CONFIG_RECT_TX && !CONFIG_VAR_TX
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
   mbmi->tx_size =
       xd->lossless[mbmi->segment_id] ? TX_4X4 : max_txsize_rect_lookup[bsize];
 #else
   mbmi->tx_size = TX_4X4;
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX && !CONFIG_VAR_TX
+#endif
 
   vp10_zero(*bsi);
 
@@ -10512,6 +10512,10 @@
 
   // macroblock modes
   *mbmi = best_mbmode;
+#if CONFIG_VAR_TX && CONFIG_EXT_TX && CONFIG_RECT_TX
+  mbmi->inter_tx_size[0][0] = mbmi->tx_size;
+#endif
+
   x->skip |= best_skip2;
   if (!is_inter_block(&best_mbmode)) {
     for (i = 0; i < 4; i++) xd->mi[0]->bmi[i].as_mode = best_bmodes[i].as_mode;