Rework txk_type indexing system for chroma component

Use the row and column indexes to fetch txk_type, which allows the
chroma components to derive the tx type from the corresponding luma
components. It improves the coding performance of txk-sel by 0.18%.

Change-Id: I3f4bca5839e13ae95e51053e76cd86fe58202ac9
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 66fe204..5125352 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -1114,11 +1114,12 @@
 }
 
 static INLINE TX_TYPE av1_get_tx_type(PLANE_TYPE plane_type,
-                                      const MACROBLOCKD *xd, int block,
-                                      TX_SIZE tx_size) {
+                                      const MACROBLOCKD *xd, int blk_row,
+                                      int blk_col, int block, TX_SIZE tx_size) {
   const MODE_INFO *const mi = xd->mi[0];
   const MB_MODE_INFO *const mbmi = &mi->mbmi;
-
+  (void)blk_row;
+  (void)blk_col;
 #if CONFIG_INTRABC && (!CONFIG_EXT_TX || CONFIG_TXK_SEL)
   // TODO(aconverse@google.com): Handle INTRABC + EXT_TX + TXK_SEL
   if (is_intrabc_block(mbmi)) return DCT_DCT;
@@ -1126,11 +1127,15 @@
 
 #if CONFIG_TXK_SEL
   TX_TYPE tx_type;
-  if (plane_type != PLANE_TYPE_Y || xd->lossless[mbmi->segment_id] ||
-      txsize_sqr_map[tx_size] >= TX_32X32) {
+  if (xd->lossless[mbmi->segment_id] || txsize_sqr_map[tx_size] >= TX_32X32) {
     tx_type = DCT_DCT;
   } else {
-    tx_type = mbmi->txk_type[block];
+    if (plane_type == PLANE_TYPE_Y)
+      tx_type = mbmi->txk_type[(blk_row << 4) + blk_col];
+    else if (is_inter_block(mbmi))
+      tx_type = mbmi->txk_type[(blk_row << 5) + (blk_col << 1)];
+    else
+      tx_type = intra_mode_to_tx_type_context[mbmi->uv_mode];
   }
   assert(tx_type >= DCT_DCT && tx_type < TX_TYPES);
   return tx_type;
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 38de47b..0c6ee3a 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -2125,7 +2125,8 @@
   tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
   const PLANE_TYPE plane_type = get_plane_type(plane);
   const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   const int dst_stride = pd->dst.stride;
   uint8_t *dst =
       &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index a0190d7..4a18776 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -610,9 +610,11 @@
     av1_read_coeffs_txb_facade(cm, xd, r, row, col, block_idx, plane,
                                pd->dqcoeff, tx_size, &max_scan_line, &eob);
     // tx_type will be read out in av1_read_coeffs_txb_facade
-    const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
+    const TX_TYPE tx_type =
+        av1_get_tx_type(plane_type, xd, row, col, block_idx, tx_size);
 #else   // CONFIG_LV_MAP
-    const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
+    const TX_TYPE tx_type =
+        av1_get_tx_type(plane_type, xd, row, col, block_idx, tx_size);
     const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, mbmi);
     int16_t max_scan_line = 0;
     const int eob =
@@ -643,7 +645,8 @@
 #endif  // CONFIG_DPCM_INTRA
     }
 #else   // !CONFIG_PVQ
-    const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
+    const TX_TYPE tx_type =
+        av1_get_tx_type(plane_type, xd, row, col, block_idx, tx_size);
     av1_pvq_decode_helper2(cm, xd, mbmi, plane, row, col, tx_size, tx_type);
 #endif  // !CONFIG_PVQ
   }
@@ -693,10 +696,10 @@
                                pd->dqcoeff, tx_size, &max_scan_line, &eob);
     // tx_type will be read out in av1_read_coeffs_txb_facade
     const TX_TYPE tx_type =
-        av1_get_tx_type(plane_type, xd, block, plane_tx_size);
+        av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, plane_tx_size);
 #else   // CONFIG_LV_MAP
     const TX_TYPE tx_type =
-        av1_get_tx_type(plane_type, xd, block, plane_tx_size);
+        av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, plane_tx_size);
     const SCAN_ORDER *sc = get_scan(cm, plane_tx_size, tx_type, mbmi);
     int16_t max_scan_line = 0;
     const int eob = av1_decode_block_tokens(
@@ -759,10 +762,12 @@
   av1_read_coeffs_txb_facade(cm, xd, r, row, col, block_idx, plane, pd->dqcoeff,
                              tx_size, &max_scan_line, &eob);
   // tx_type will be read out in av1_read_coeffs_txb_facade
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, row, col, block_idx, tx_size);
 #else   // CONFIG_LV_MAP
   int16_t max_scan_line = 0;
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, row, col, block_idx, tx_size);
   const SCAN_ORDER *scan_order =
       get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
   const int eob =
@@ -779,7 +784,8 @@
                             tx_type, tx_size, dst, pd->dst.stride,
                             max_scan_line, eob);
 #else
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, row, col, block_idx, tx_size);
   eob = av1_pvq_decode_helper2(cm, xd, &xd->mi[0]->mbmi, plane, row, col,
                                tx_size, tx_type);
 #endif
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index fe904de..166721b 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -985,7 +985,8 @@
                       int supertx_enabled,
 #endif
 #if CONFIG_TXK_SEL
-                      int block, int plane, TX_SIZE tx_size,
+                      int blk_row, int blk_col, int block, int plane,
+                      TX_SIZE tx_size,
 #endif
                       aom_reader *r) {
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
@@ -1004,7 +1005,8 @@
 #else
   // only y plane's tx_type is transmitted
   if (plane > 0) return;
-  TX_TYPE *tx_type = &mbmi->txk_type[block];
+  (void)block;
+  TX_TYPE *tx_type = &mbmi->txk_type[(blk_row << 4) + blk_col];
 #endif
 
   if (!FIXED_TX_TYPE) {
diff --git a/av1/decoder/decodemv.h b/av1/decoder/decodemv.h
index 9538e96..162cf32 100644
--- a/av1/decoder/decodemv.h
+++ b/av1/decoder/decodemv.h
@@ -37,7 +37,8 @@
                       int supertx_enabled,
 #endif
 #if CONFIG_TXK_SEL
-                      int block, int plane, TX_SIZE tx_size,
+                      int blk_row, int blk_col, int block, int plane,
+                      TX_SIZE tx_size,
 #endif
                       aom_reader *r);
 
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index d499a7d..bf4f01b 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -73,15 +73,20 @@
   *eob = 0;
   if (all_zero) {
     *max_scan_line = 0;
+#if CONFIG_TXK_SEL
+    if (plane == 0) mbmi->txk_type[(blk_row << 4) + blk_col] = DCT_DCT;
+#endif
     return 0;
   }
 
   (void)blk_row;
   (void)blk_col;
 #if CONFIG_TXK_SEL
-  av1_read_tx_type(cm, xd, block, plane, get_min_tx_size(tx_size), r);
+  av1_read_tx_type(cm, xd, blk_row, blk_col, block, plane,
+                   get_min_tx_size(tx_size), r);
 #endif
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int16_t *scan = scan_order->scan;
   const int16_t *iscan = scan_order->iscan;
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index f5a1a2b..1c0db25 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1563,7 +1563,8 @@
                        const int supertx_enabled,
 #endif
 #if CONFIG_TXK_SEL
-                       int block, int plane, TX_SIZE tx_size,
+                       int blk_row, int blk_col, int block, int plane,
+                       TX_SIZE tx_size,
 #endif
                        aom_writer *w) {
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
@@ -1583,7 +1584,8 @@
   // Only y plane's tx_type is transmitted
   if (plane > 0) return;
   PLANE_TYPE plane_type = get_plane_type(plane);
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
 #endif
 
   if (!FIXED_TX_TYPE) {
diff --git a/av1/encoder/bitstream.h b/av1/encoder/bitstream.h
index 5a57047..fd23074 100644
--- a/av1/encoder/bitstream.h
+++ b/av1/encoder/bitstream.h
@@ -42,7 +42,8 @@
                        const int supertx_enabled,
 #endif
 #if CONFIG_TXK_SEL
-                       int block, int plane, TX_SIZE tx_size,
+                       int blk_row, int blk_col, int block, int plane,
+                       TX_SIZE tx_size,
 #endif
                        aom_writer *w);
 
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index adcfeb4..9531abe 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -5902,7 +5902,8 @@
   (void)blk_col;
   // Only y plane's tx_type is updated
   if (plane > 0) return;
-  const TX_TYPE tx_type = av1_get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
+  TX_TYPE tx_type =
+      av1_get_tx_type(PLANE_TYPE_Y, xd, blk_row, blk_col, block, tx_size);
 #endif
 #if CONFIG_EXT_TX
   if (get_ext_tx_types(tx_size, bsize, is_inter, cm->reduced_tx_set_used) > 1 &&
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index b397f73..e7b136a 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -125,7 +125,8 @@
 #if !CONFIG_LV_MAP
 
 static int optimize_b_greedy(const AV1_COMMON *cm, MACROBLOCK *mb, int plane,
-                             int block, TX_SIZE tx_size, int ctx) {
+                             int blk_row, int blk_col, int block,
+                             TX_SIZE tx_size, int ctx) {
   MACROBLOCKD *const xd = &mb->e_mbd;
   struct macroblock_plane *const p = &mb->plane[plane];
   struct macroblockd_plane *const pd = &xd->plane[plane];
@@ -138,7 +139,8 @@
   const PLANE_TYPE plane_type = pd->plane_type;
   const int16_t *const dequant_ptr = pd->dequant;
   const uint8_t *const band_translate = get_band_translate(tx_size);
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   const SCAN_ORDER *const scan_order =
       get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
   const int16_t *const scan = scan_order->scan;
@@ -458,7 +460,8 @@
 #else
   int ctx = combine_entropy_contexts(*a, *l);
 #endif  // CONFIG_VAR_TX
-  return optimize_b_greedy(cm, mb, plane, block, tx_size, ctx);
+  return optimize_b_greedy(cm, mb, plane, blk_row, blk_col, block, tx_size,
+                           ctx);
 #else   // !CONFIG_LV_MAP
   TXB_CTX txb_ctx;
   get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
@@ -508,7 +511,8 @@
   struct macroblockd_plane *const pd = &xd->plane[plane];
 #endif
   PLANE_TYPE plane_type = get_plane_type(plane);
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
 
 #if CONFIG_AOM_QM || CONFIG_NEW_QUANT
   const int is_inter = is_inter_block(mbmi);
@@ -753,7 +757,8 @@
 
   if (x->pvq_skip[plane]) return;
 #endif
-  const TX_TYPE tx_type = av1_get_tx_type(pd->plane_type, xd, block, tx_size);
+  TX_TYPE tx_type =
+      av1_get_tx_type(pd->plane_type, xd, blk_row, blk_col, block, tx_size);
 #if CONFIG_LGT
   PREDICTION_MODE mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
   av1_inverse_transform_block(xd, dqcoeff, mode, tx_type, tx_size, dst,
@@ -1345,7 +1350,8 @@
   struct macroblockd_plane *const pd = &xd->plane[plane];
   tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
   PLANE_TYPE plane_type = get_plane_type(plane);
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   uint16_t *eob = &p->eobs[block];
   const int dst_stride = pd->dst.stride;
   uint8_t *dst =
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index d067b5c..3aa4c18 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -78,7 +78,8 @@
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const PLANE_TYPE plane_type = get_plane_type(plane);
   const TX_SIZE txs_ctx = get_txsize_context(tx_size);
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int16_t *scan = scan_order->scan;
   const int16_t *iscan = scan_order->iscan;
@@ -96,7 +97,8 @@
 
   if (eob == 0) return;
 #if CONFIG_TXK_SEL
-  av1_write_tx_type(cm, xd, block, plane, get_min_tx_size(tx_size), w);
+  av1_write_tx_type(cm, xd, blk_row, blk_col, block, plane,
+                    get_min_tx_size(tx_size), w);
 #endif
 
   nz_map = cm->fc->nz_map[txs_ctx][plane_type];
@@ -292,9 +294,8 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   TX_SIZE txs_ctx = get_txsize_context(tx_size);
   const PLANE_TYPE plane_type = get_plane_type(plane);
-  (void)blk_row;
-  (void)blk_col;
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const struct macroblock_plane *p = &x->plane[plane];
   const int eob = p->eobs[block];
@@ -1466,9 +1467,8 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   const PLANE_TYPE plane_type = get_plane_type(plane);
   const TX_SIZE txs_ctx = get_txsize_context(tx_size);
-  (void)blk_row;
-  (void)blk_col;
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   const MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const struct macroblock_plane *p = &x->plane[plane];
   struct macroblockd_plane *pd = &xd->plane[plane];
@@ -1543,7 +1543,8 @@
   const uint16_t eob = p->eobs[block];
   const tran_low_t *qcoeff = BLOCK_OFFSET(p->qcoeff, block);
   const PLANE_TYPE plane_type = pd->plane_type;
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   (void)plane_bsize;
 
@@ -1568,7 +1569,8 @@
   const tran_low_t *qcoeff = BLOCK_OFFSET(p->qcoeff, block);
   tran_low_t *tcoeff = BLOCK_OFFSET(x->mbmi_ext->tcoeff[plane], block);
   const int segment_id = mbmi->segment_id;
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int16_t *scan = scan_order->scan;
   const int16_t *iscan = scan_order->iscan;
@@ -1876,9 +1878,9 @@
   av1_invalid_rd_stats(&best_rd_stats);
 
   for (tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
-    if (plane == 0) mbmi->txk_type[block] = tx_type;
-    const TX_TYPE ref_tx_type =
-        av1_get_tx_type(get_plane_type(plane), xd, block, tx_size);
+    if (plane == 0) mbmi->txk_type[(blk_row << 4) + blk_col] = tx_type;
+    TX_TYPE ref_tx_type = av1_get_tx_type(get_plane_type(plane), xd, blk_row,
+                                          blk_col, block, tx_size);
     if (tx_type != ref_tx_type) {
       // use av1_get_tx_type() to check if the tx_type is valid for the current
       // mode if it's not, we skip it here.
@@ -1908,6 +1910,7 @@
         av1_cost_coeffs(cpi, x, plane, blk_row, blk_col, block, tx_size,
                         scan_order, a, l, use_fast_coef_costing);
     int rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
+
     if (rd < best_rd) {
       best_rd = rd;
       best_rd_stats = this_rd_stats;
@@ -1918,7 +1921,13 @@
 
   av1_merge_rd_stats(rd_stats, &best_rd_stats);
 
-  if (plane == 0) mbmi->txk_type[block] = best_tx_type;
+  //  if (x->plane[plane].eobs[block] == 0)
+  //    if (best_tx_type != DCT_DCT)
+  //      exit(0);
+
+  if (best_eob == 0 && is_inter_block(mbmi)) best_tx_type = DCT_DCT;
+
+  if (plane == 0) mbmi->txk_type[(blk_row << 4) + blk_col] = best_tx_type;
   x->plane[plane].txb_entropy_ctx[block] = best_eob;
 
   if (!is_inter_block(mbmi)) {
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 9b6f3dc..f718a1b 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1745,8 +1745,8 @@
 #endif  // !CONFIG_PVQ
 
         const PLANE_TYPE plane_type = get_plane_type(plane);
-        const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
-
+        TX_TYPE tx_type =
+            av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
         av1_inverse_transform_block(xd, dqcoeff,
 #if CONFIG_LGT
                                     xd->mi[0]->mbmi.mode,
@@ -1821,8 +1821,9 @@
         av1_block_index_to_raster_order(tx_size, block);
     const PREDICTION_MODE mode =
         (plane == 0) ? get_y_mode(xd->mi[0], block_raster_idx) : mbmi->uv_mode;
-    const TX_TYPE tx_type = av1_get_tx_type(
-        (plane == 0) ? PLANE_TYPE_Y : PLANE_TYPE_UV, xd, block, tx_size);
+    TX_TYPE tx_type =
+        av1_get_tx_type((plane == 0) ? PLANE_TYPE_Y : PLANE_TYPE_UV, xd,
+                        blk_row, blk_col, block, tx_size);
     if (av1_use_dpcm_intra(plane, mode, tx_type, mbmi)) {
       int8_t skip;
       av1_encode_block_intra_dpcm(cm, x, mode, plane, block, blk_row, blk_col,
@@ -1882,7 +1883,8 @@
   }
 #if !CONFIG_PVQ
   const PLANE_TYPE plane_type = get_plane_type(plane);
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   this_rd_stats.rate =
       av1_cost_coeffs(cpi, x, plane, blk_row, blk_col, block, tx_size,
@@ -2462,7 +2464,6 @@
   TX_TYPE best_tx_type = DCT_DCT;
 #if CONFIG_TXK_SEL
   TX_TYPE best_txk_type[MAX_SB_SQUARE / (TX_SIZE_W_MIN * TX_SIZE_H_MIN)];
-  const int num_blk = bsize_to_num_blk(bs);
 #endif  // CONFIG_TXK_SEL
   const int tx_select = cm->tx_mode == TX_MODE_SELECT;
   const int is_inter = is_inter_block(mbmi);
@@ -2504,8 +2505,7 @@
                       rect_tx_size);
         if (rd < best_rd) {
 #if CONFIG_TXK_SEL
-          memcpy(best_txk_type, mbmi->txk_type,
-                 sizeof(best_txk_type[0]) * num_blk);
+          memcpy(best_txk_type, mbmi->txk_type, sizeof(best_txk_type[0]) * 256);
 #endif
           best_tx_type = tx_type;
           best_tx_size = rect_tx_size;
@@ -2611,8 +2611,7 @@
       last_rd = rd;
       if (rd < best_rd) {
 #if CONFIG_TXK_SEL
-        memcpy(best_txk_type, mbmi->txk_type,
-               sizeof(best_txk_type[0]) * num_blk);
+        memcpy(best_txk_type, mbmi->txk_type, sizeof(best_txk_type[0]) * 256);
 #endif
         best_tx_type = tx_type;
         best_tx_size = n;
@@ -2628,7 +2627,7 @@
   mbmi->tx_size = best_tx_size;
   mbmi->tx_type = best_tx_type;
 #if CONFIG_TXK_SEL
-  memcpy(mbmi->txk_type, best_txk_type, sizeof(best_txk_type[0]) * num_blk);
+  memcpy(mbmi->txk_type, best_txk_type, sizeof(best_txk_type[0]) * 256);
 #endif
 
 #if CONFIG_VAR_TX
@@ -3093,8 +3092,8 @@
                                     src_stride, dst, dst_stride, xd->bd);
 #endif
           if (is_lossless) {
-            const TX_TYPE tx_type =
-                av1_get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
+            TX_TYPE tx_type =
+                av1_get_tx_type(PLANE_TYPE_Y, xd, 0, 0, block, tx_size);
             const SCAN_ORDER *scan_order =
                 get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
             const int coeff_ctx =
@@ -3142,8 +3141,8 @@
           } else {
             int64_t dist;
             unsigned int tmp;
-            const TX_TYPE tx_type =
-                av1_get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
+            TX_TYPE tx_type =
+                av1_get_tx_type(PLANE_TYPE_Y, xd, 0, 0, block, tx_size);
             const SCAN_ORDER *scan_order =
                 get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
             const int coeff_ctx =
@@ -3301,9 +3300,8 @@
         aom_subtract_block(tx_height, tx_width, src_diff, 8, src, src_stride,
                            dst, dst_stride);
 #endif  // !CONFIG_PVQ
-
-        const TX_TYPE tx_type =
-            av1_get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
+        TX_TYPE tx_type =
+            av1_get_tx_type(PLANE_TYPE_Y, xd, 0, 0, block, tx_size);
         const SCAN_ORDER *scan_order =
             get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
         const int coeff_ctx = combine_entropy_contexts(tempa[idx], templ[idy]);
@@ -4236,7 +4234,8 @@
   int64_t tmp;
   tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
   PLANE_TYPE plane_type = get_plane_type(plane);
-  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
+  TX_TYPE tx_type =
+      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   const SCAN_ORDER *const scan_order =
       get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
   BLOCK_SIZE txm_bsize = txsize_to_bsize[tx_size];
@@ -4383,7 +4382,7 @@
   RD_STATS sum_rd_stats;
 #if CONFIG_TXK_SEL
   TX_TYPE best_tx_type = TX_TYPES;
-  int txk_idx = block;
+  int txk_idx = (blk_row << 4) + blk_col;
 #endif
 
   av1_init_rd_stats(&sum_rd_stats);
@@ -4441,6 +4440,9 @@
       rd_stats->skip = 1;
       x->blk_skip[plane][blk_row * bw + blk_col] = 1;
       p->eobs[block] = 0;
+#if CONFIG_TXK_SEL
+      mbmi->txk_type[txk_idx] = DCT_DCT;
+#endif
     } else {
       x->blk_skip[plane][blk_row * bw + blk_col] = 0;
       rd_stats->skip = 0;
diff --git a/av1/encoder/tokenize.c b/av1/encoder/tokenize.c
index 23d452a..2074902 100644
--- a/av1/encoder/tokenize.c
+++ b/av1/encoder/tokenize.c
@@ -277,7 +277,8 @@
   struct macroblock_plane *p = &x->plane[plane];
   struct macroblockd_plane *pd = &xd->plane[plane];
   const PLANE_TYPE type = pd->plane_type;
-  const TX_TYPE tx_type = av1_get_tx_type(type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(type, xd, blk_row, blk_col, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int rate = av1_cost_coeffs(
       cpi, x, plane, blk_row, blk_col, block, tx_size, scan_order,
@@ -445,7 +446,8 @@
   const int segment_id = mbmi->segment_id;
 #endif  // CONFIG_SUEPRTX
   const int16_t *scan, *nb;
-  const TX_TYPE tx_type = av1_get_tx_type(type, xd, block, tx_size);
+  const TX_TYPE tx_type =
+      av1_get_tx_type(type, xd, blk_row, blk_col, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int ref = is_inter_block(mbmi);
   unsigned int(*const counts)[COEFF_CONTEXTS][ENTROPY_TOKENS] =