Refactor var-tx pipeline to support cb4x4 mode

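Replace the hard-coded 4x4 transform block step size with lookups
into the block and transform size tables, so the step scales with the
base transform unit (tx_size_wide_log2[0] / tx_size_high_log2[0]).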

Change-Id: Ib1cc555c2641e5634acdd91ca33217f00aeb0b89
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h
index 49d09fc..75948b6 100644
--- a/av1/common/pred_common.h
+++ b/av1/common/pred_common.h
@@ -237,11 +237,10 @@
                                                MB_MODE_INFO *mbmi,
                                                BLOCK_SIZE plane_bsize,
                                                int ctx) {
-  const int mi_width = num_4x4_blocks_wide_lookup[plane_bsize];
-  const int mi_height = num_4x4_blocks_high_lookup[plane_bsize];
+  const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
+  const int mi_height = block_size_high[plane_bsize] >> tx_size_high_log2[0];
   TX_SIZE max_tx_size = max_txsize_lookup[plane_bsize];
-  BLOCK_SIZE txb_size = txsize_to_bsize[max_tx_size];
-  int bh = num_4x4_blocks_wide_lookup[txb_size];
+  int bh = tx_size_wide_unit[max_tx_size];
   int idx, idy;
 
   for (idy = 0; idy < mi_height; idy += bh)
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 499da3d..aea295d 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -533,10 +533,10 @@
     const int eob =
         av1_decode_block_tokens(xd, plane, sc, blk_col, blk_row, plane_tx_size,
                                 tx_type, &max_scan_line, r, mbmi->segment_id);
-    inverse_transform_block(
-        xd, plane, tx_type, plane_tx_size,
-        &pd->dst.buf[4 * blk_row * pd->dst.stride + 4 * blk_col],
-        pd->dst.stride, max_scan_line, eob);
+    inverse_transform_block(xd, plane, tx_type, plane_tx_size,
+                            &pd->dst.buf[(blk_row * pd->dst.stride + blk_col)
+                                         << tx_size_wide_log2[0]],
+                            pd->dst.stride, max_scan_line, eob);
     *eob_total += eob;
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 97301a9..59a646d 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -346,7 +346,11 @@
     if (counts) ++counts->txfm_partition[ctx][1];
 
     if (tx_size == TX_8X8) {
+      int idx, idy;
       inter_tx_size[0][0] = TX_4X4;
+      for (idy = 0; idy < tx_size_high_unit[tx_size] / 2; ++idy)
+        for (idx = 0; idx < tx_size_wide_unit[tx_size] / 2; ++idx)
+          inter_tx_size[idy][idx] = TX_4X4;
       mbmi->tx_size = TX_4X4;
       mbmi->min_tx_size = get_min_tx_size(mbmi->tx_size);
       txfm_partition_update(xd->above_txfm_context + tx_col,
@@ -1969,8 +1973,8 @@
         mbmi->tx_size = read_tx_size_intra(cm, xd, r);
 
       if (inter_block) {
-        const int width = num_4x4_blocks_wide_lookup[bsize];
-        const int height = num_4x4_blocks_high_lookup[bsize];
+        const int width = block_size_wide[bsize] >> tx_size_wide_log2[0];
+        const int height = block_size_high[bsize] >> tx_size_high_log2[0];
         int idx, idy;
         for (idy = 0; idy < height; ++idy)
           for (idx = 0; idx < width; ++idx)
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 9e153c5..27eb8df 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1327,8 +1327,8 @@
       const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
       const int bh = tx_size_high_unit[max_tx_size];
       const int bw = tx_size_wide_unit[max_tx_size];
-      const int width = num_4x4_blocks_wide_lookup[bsize];
-      const int height = num_4x4_blocks_high_lookup[bsize];
+      const int width = block_size_wide[bsize] >> tx_size_wide_log2[0];
+      const int height = block_size_high[bsize] >> tx_size_high_log2[0];
       int idx, idy;
       for (idy = 0; idy < height; idy += bh)
         for (idx = 0; idx < width; idx += bw)
@@ -1999,8 +1999,12 @@
 #if CONFIG_VAR_TX
       const struct macroblockd_plane *const pd = &xd->plane[plane];
       BLOCK_SIZE bsize = mbmi->sb_type;
+#if CONFIG_CB4X4
+      const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, pd);
+#else
       const BLOCK_SIZE plane_bsize =
           get_plane_block_size(AOMMAX(bsize, BLOCK_8X8), pd);
+#endif
 
       const int num_4x4_w =
           block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 436a000..b8e886b 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -158,9 +158,11 @@
   int mv_row_max;
 
 #if CONFIG_VAR_TX
-  uint8_t blk_skip[MAX_MB_PLANE][MAX_MIB_SIZE * MAX_MIB_SIZE * 4];
+  // Row-major skip flags, one per base (tx_size_*_log2[0]) transform unit.
+  // NOTE(review): confirm the x8 factor is the bound CONFIG_CB4X4 needs.
+  uint8_t blk_skip[MAX_MB_PLANE][MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
 #if CONFIG_REF_MV
-  uint8_t blk_skip_drl[MAX_MB_PLANE][MAX_MIB_SIZE * MAX_MIB_SIZE * 4];
+  uint8_t blk_skip_drl[MAX_MB_PLANE][MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
 #endif
 #endif
 
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index dcd5e8f..a32d002 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -5261,8 +5261,8 @@
                                       BLOCK_SIZE plane_bsize, int mi_row,
                                       int mi_col, FRAME_COUNTS *td_counts) {
   MACROBLOCKD *xd = &x->e_mbd;
-  const int mi_width = num_4x4_blocks_wide_lookup[plane_bsize];
-  const int mi_height = num_4x4_blocks_high_lookup[plane_bsize];
+  const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
+  const int mi_height = block_size_high[plane_bsize] >> tx_size_high_log2[0];
   TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
   const int bh = tx_size_high_unit[max_tx_size];
   const int bw = tx_size_wide_unit[max_tx_size];
@@ -5319,8 +5319,8 @@
 static void tx_partition_set_contexts(const AV1_COMMON *const cm,
                                       MACROBLOCKD *xd, BLOCK_SIZE plane_bsize,
                                       int mi_row, int mi_col) {
-  const int mi_width = num_4x4_blocks_wide_lookup[plane_bsize];
-  const int mi_height = num_4x4_blocks_high_lookup[plane_bsize];
+  const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
+  const int mi_height = block_size_high[plane_bsize] >> tx_size_high_log2[0];
   TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
   const int bh = tx_size_high_unit[max_tx_size];
   const int bw = tx_size_wide_unit[max_tx_size];
@@ -5482,8 +5482,8 @@
     av1_encode_sb((AV1_COMMON *)cm, x, block_size);
 #if CONFIG_VAR_TX
     if (mbmi->skip) mbmi->min_tx_size = get_min_tx_size(mbmi->tx_size);
-    av1_tokenize_sb_vartx(cpi, td, t, dry_run, mi_row, mi_col,
-                          AOMMAX(bsize, BLOCK_8X8), rate);
+    av1_tokenize_sb_vartx(cpi, td, t, dry_run, mi_row, mi_col, block_size,
+                          rate);
 #else
     av1_tokenize_sb(cpi, td, t, dry_run, block_size, rate);
 #endif
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 6573af9..25c2de8 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -632,7 +632,7 @@
 #endif
 #if CONFIG_VAR_TX
   int i;
-  const int bwl = b_width_log2_lookup[plane_bsize];
+  const int bw = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
 #endif
   dst = &pd->dst
              .buf[(blk_row * pd->dst.stride + blk_col) << tx_size_wide_log2[0]];
@@ -646,9 +646,9 @@
 
 #if CONFIG_VAR_TX
   // Assert not magic number (uninitialized).
-  assert(x->blk_skip[plane][(blk_row << bwl) + blk_col] != 234);
+  assert(x->blk_skip[plane][blk_row * bw + blk_col] != 234);
 
-  if (x->blk_skip[plane][(blk_row << bwl) + blk_col] == 0) {
+  if (x->blk_skip[plane][blk_row * bw + blk_col] == 0) {
 #else
   {
 #endif
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 32b9798..eb9a951 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3150,8 +3150,11 @@
   int txb_w = tx_size_wide_unit[tx_size];
 
   int src_stride = p->src.stride;
-  uint8_t *src = &p->src.buf[4 * blk_row * src_stride + 4 * blk_col];
-  uint8_t *dst = &pd->dst.buf[4 * blk_row * pd->dst.stride + 4 * blk_col];
+  uint8_t *src =
+      &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
+  uint8_t *dst =
+      &pd->dst
+           .buf[(blk_row * pd->dst.stride + blk_col) << tx_size_wide_log2[0]];
 #if CONFIG_AOM_HIGHBITDEPTH
   DECLARE_ALIGNED(16, uint16_t, rec_buffer16[MAX_TX_SQUARE]);
   uint8_t *rec_buffer;
@@ -3161,7 +3164,8 @@
   int max_blocks_high = block_size_high[plane_bsize];
   int max_blocks_wide = block_size_wide[plane_bsize];
   const int diff_stride = max_blocks_wide;
-  const int16_t *diff = &p->src_diff[4 * (blk_row * diff_stride + blk_col)];
+  const int16_t *diff =
+      &p->src_diff[(blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]];
   int txb_coeff_cost;
 
   assert(tx_size < TX_SIZES_ALL);
@@ -3206,10 +3210,12 @@
     int blocks_height = AOMMIN(txb_h, max_blocks_high - blk_row);
     int blocks_width = AOMMIN(txb_w, max_blocks_wide - blk_col);
     tmp = 0;
-    for (idy = 0; idy < blocks_height; idy += 2) {
-      for (idx = 0; idx < blocks_width; idx += 2) {
-        const int16_t *d = diff + 4 * idy * diff_stride + 4 * idx;
-        tmp += aom_sum_squares_2d_i16(d, diff_stride, 8);
+    // NOTE(review): the literal 4 below still assumes the base tx unit is 4x4.
+    for (idy = 0; idy < blocks_height; ++idy) {
+      for (idx = 0; idx < blocks_width; ++idx) {
+        const int16_t *d =
+            diff + ((idy * diff_stride + idx) << tx_size_wide_log2[0]);
+        tmp += aom_sum_squares_2d_i16(d, diff_stride, 4);
       }
     }
   } else {
@@ -3247,11 +3253,14 @@
       int blocks_height = AOMMIN(txb_h, max_blocks_high - blk_row);
       int blocks_width = AOMMIN(txb_w, max_blocks_wide - blk_col);
       tmp = 0;
-      for (idy = 0; idy < blocks_height; idy += 2) {
-        for (idx = 0; idx < blocks_width; idx += 2) {
-          uint8_t *const s = src + 4 * idy * src_stride + 4 * idx;
-          uint8_t *const r = rec_buffer + 4 * idy * MAX_TX_SIZE + 4 * idx;
-          cpi->fn_ptr[BLOCK_8X8].vf(s, src_stride, r, MAX_TX_SIZE, &this_dist);
+      // NOTE(review): BLOCK_4X4 below still assumes the base tx unit is 4x4.
+      for (idy = 0; idy < blocks_height; ++idy) {
+        for (idx = 0; idx < blocks_width; ++idx) {
+          uint8_t *const s =
+              src + ((idy * src_stride + idx) << tx_size_wide_log2[0]);
+          uint8_t *const r =
+              rec_buffer + ((idy * MAX_TX_SIZE + idx) << tx_size_wide_log2[0]);
+          cpi->fn_ptr[BLOCK_4X4].vf(s, src_stride, r, MAX_TX_SIZE, &this_dist);
           tmp += this_dist;
         }
       }
@@ -3428,8 +3437,8 @@
   if (is_cost_valid) {
     const struct macroblockd_plane *const pd = &xd->plane[0];
     const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, pd);
-    const int mi_width = num_4x4_blocks_wide_lookup[plane_bsize];
-    const int mi_height = num_4x4_blocks_high_lookup[plane_bsize];
+    const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
+    const int mi_height = block_size_high[plane_bsize] >> tx_size_high_log2[0];
     const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
     const int bh = tx_size_high_unit[max_tx_size];
     const int bw = tx_size_wide_unit[max_tx_size];
@@ -3445,7 +3454,8 @@
     RD_STATS pn_rd_stats;
     av1_init_rd_stats(&pn_rd_stats);
 
-    av1_get_entropy_contexts(bsize, TX_4X4, pd, ctxa, ctxl);
+    // Pass TX_SIZE 0, the base unit that the tx_size_*_log2[0] shifts use.
+    av1_get_entropy_contexts(bsize, 0, pd, ctxa, ctxl);
     memcpy(tx_above, xd->above_txfm_context,
            sizeof(TXFM_CONTEXT) * (mi_width >> 1));
     memcpy(tx_left, xd->left_txfm_context,
@@ -3552,8 +3562,9 @@
   TX_SIZE best_tx_size[MAX_MIB_SIZE][MAX_MIB_SIZE];
   TX_SIZE best_tx = max_txsize_lookup[bsize];
   TX_SIZE best_min_tx_size = TX_SIZES_ALL;
-  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 4];
-  const int n4 = 1 << (num_pels_log2_lookup[bsize] - 4);
+  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
+  const int n4 = 1 << (num_pels_log2_lookup[bsize] - tx_size_wide_log2[0] -
+                       tx_size_high_log2[0]);
   int idx, idy;
   int prune = 0;
   const int count32 = 1 << (2 * (cpi->common.mib_size_log2 -
@@ -3716,8 +3727,8 @@
   for (plane = 1; plane < MAX_MB_PLANE; ++plane) {
     const struct macroblockd_plane *const pd = &xd->plane[plane];
     const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, pd);
-    const int mi_width = num_4x4_blocks_wide_lookup[plane_bsize];
-    const int mi_height = num_4x4_blocks_high_lookup[plane_bsize];
+    const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
+    const int mi_height = block_size_high[plane_bsize] >> tx_size_high_log2[0];
     const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
     const int bh = tx_size_high_unit[max_tx_size];
     const int bw = tx_size_wide_unit[max_tx_size];
@@ -3729,7 +3740,8 @@
     RD_STATS pn_rd_stats;
     av1_init_rd_stats(&pn_rd_stats);
 
-    av1_get_entropy_contexts(bsize, TX_4X4, pd, ta, tl);
+    // Pass TX_SIZE 0, the base unit that the tx_size_*_log2[0] shifts use.
+    av1_get_entropy_contexts(bsize, 0, pd, ta, tl);
 
     for (idy = 0; idy < mi_height; idy += bh) {
       for (idx = 0; idx < mi_width; idx += bw) {