Set up txb size properly for TX64X64

TX64X64 uses 32x32 coeff buffer

Change-Id: Ied4279807207176d590af4c1fc4bb648a618d158
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index 2d94ae5..383bde3 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -66,6 +66,30 @@
   }
 }
 
+static INLINE int get_txb_bwl(TX_SIZE tx_size) {
+#if CONFIG_TX64X64
+  if (tx_size == TX_64X64 || tx_size == TX_64X32 || tx_size == TX_32X64)
+    tx_size = TX_32X32;
+#endif
+  return tx_size_wide_log2[tx_size];
+}
+
+static INLINE int get_txb_wide(TX_SIZE tx_size) {
+#if CONFIG_TX64X64
+  if (tx_size == TX_64X64 || tx_size == TX_64X32 || tx_size == TX_32X64)
+    tx_size = TX_32X32;
+#endif
+  return tx_size_wide[tx_size];
+}
+
+static INLINE int get_txb_high(TX_SIZE tx_size) {
+#if CONFIG_TX64X64
+  if (tx_size == TX_64X64 || tx_size == TX_64X32 || tx_size == TX_32X64)
+    tx_size = TX_32X32;
+#endif
+  return tx_size_high[tx_size];
+}
+
 static INLINE void get_base_count_mag(int *mag, int *count,
                                       const tran_low_t *tcoeffs, int bwl,
                                       int height, int row, int col) {
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 56619f9..8c6b5bf 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -75,9 +75,9 @@
   const int16_t *const dequant =
       xd->plane[plane].seg_dequant_QTX[mbmi->segment_id];
   const int shift = av1_get_tx_scale(tx_size);
-  const int bwl = tx_size_wide_log2[tx_size];
-  const int width = tx_size_wide[tx_size];
-  const int height = tx_size_high[tx_size];
+  const int bwl = get_txb_bwl(tx_size);
+  const int width = get_txb_wide(tx_size);
+  const int height = get_txb_high(tx_size);
   int cul_level = 0;
   uint8_t levels_buf[TX_PAD_2D];
   uint8_t *const levels = set_levels(levels_buf, width);
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 7094e83..05889dc 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -329,9 +329,9 @@
   const int16_t *scan = scan_order->scan;
   const int seg_eob = av1_get_max_eob(tx_size);
   int c;
-  const int bwl = tx_size_wide_log2[tx_size];
-  const int width = tx_size_wide[tx_size];
-  const int height = tx_size_high[tx_size];
+  const int bwl = get_txb_bwl(tx_size);
+  const int width = get_txb_wide(tx_size);
+  const int height = get_txb_high(tx_size);
   int update_eob = -1;
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
   uint8_t levels_buf[TX_PAD_2D];
@@ -673,9 +673,9 @@
   int c, cost;
   int txb_skip_ctx = txb_ctx->txb_skip_ctx;
 
-  const int bwl = tx_size_wide_log2[tx_size];
-  const int width = tx_size_wide[tx_size];
-  const int height = tx_size_high[tx_size];
+  const int bwl = get_txb_bwl(tx_size);
+  const int width = get_txb_wide(tx_size);
+  const int height = get_txb_high(tx_size);
 
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int16_t *scan = scan_order->scan;
@@ -2125,9 +2125,9 @@
   const tran_low_t *tcoeff = BLOCK_OFFSET(p->coeff, block);
   const int16_t *dequant = p->dequant_QTX;
   const int seg_eob = av1_get_max_eob(tx_size);
-  const int bwl = tx_size_wide_log2[tx_size];
-  const int width = tx_size_wide[tx_size];
-  const int height = tx_size_high[tx_size];
+  const int bwl = get_txb_bwl(tx_size);
+  const int width = get_txb_wide(tx_size);
+  const int height = get_txb_high(tx_size);
   const int is_inter = is_inter_block(mbmi);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const LV_MAP_COEFF_COST txb_costs = x->coeff_costs[txs_ctx][plane_type];
@@ -2221,9 +2221,9 @@
   TXB_CTX txb_ctx;
   get_txb_ctx(plane_bsize, tx_size, plane, pd->above_context + blk_col,
               pd->left_context + blk_row, &txb_ctx);
-  const int bwl = tx_size_wide_log2[tx_size];
-  const int width = tx_size_wide[tx_size];
-  const int height = tx_size_high[tx_size];
+  const int bwl = get_txb_bwl(tx_size);
+  const int width = get_txb_wide(tx_size);
+  const int height = get_txb_high(tx_size);
   uint8_t levels_buf[TX_PAD_2D];
   uint8_t *const levels = set_levels(levels_buf, width);
   DECLARE_ALIGNED(16, uint8_t, level_counts[MAX_TX_SQUARE]);