refactor get_tx_type()

Change-Id: I2888bd8905253e02e3ac74597275cf56e5142d29
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index e0c803e..6b5d267 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -1113,8 +1113,9 @@
                                            : mbmi->uv_mode];
 }
 
-static INLINE TX_TYPE get_tx_type(PLANE_TYPE plane_type, const MACROBLOCKD *xd,
-                                  int block, TX_SIZE tx_size) {
+static INLINE TX_TYPE av1_get_tx_type(PLANE_TYPE plane_type,
+                                      const MACROBLOCKD *xd, int block,
+                                      TX_SIZE tx_size) {
   const MODE_INFO *const mi = xd->mi[0];
   const MB_MODE_INFO *const mbmi = &mi->mbmi;
 
@@ -1122,15 +1123,25 @@
   // TODO(aconverse@google.com): Handle INTRABC + EXT_TX + TXK_SEL
   if (is_intrabc_block(mbmi)) return DCT_DCT;
 #endif  // CONFIG_INTRABC && (!CONFIG_EXT_TX || CONFIG_TXK_SEL)
-#if !CONFIG_TXK_SEL
+
+#if CONFIG_TXK_SEL
+  TX_TYPE tx_type;
+  if (plane_type != PLANE_TYPE_Y || xd->lossless[mbmi->segment_id] ||
+      txsize_sqr_map[tx_size] >= TX_32X32) {
+    tx_type = DCT_DCT;
+  } else {
+    tx_type = mbmi->txk_type[block];
+  }
+  assert(tx_type >= DCT_DCT && tx_type < TX_TYPES);
+  return tx_type;
+#endif  // CONFIG_TXK_SEL
+
 #if FIXED_TX_TYPE
-  (void)mbmi;
   const int block_raster_idx = av1_block_index_to_raster_order(tx_size, block);
   return get_default_tx_type(plane_type, xd, block_raster_idx, tx_size);
-#elif CONFIG_EXT_TX
-#if !CONFIG_CB4X4
-  const int block_raster_idx = av1_block_index_to_raster_order(tx_size, block);
-#endif  // !CONFIG_CB4X4
+#endif  // FIXED_TX_TYPE
+
+#if CONFIG_EXT_TX
   if (xd->lossless[mbmi->segment_id] || txsize_sqr_map[tx_size] > TX_32X32 ||
       (txsize_sqr_map[tx_size] >= TX_32X32 && !is_inter_block(mbmi)))
     return DCT_DCT;
@@ -1159,17 +1170,19 @@
   if (tx_size < TX_4X4)
     return DCT_DCT;
   else
-#endif
+#endif  // CONFIG_CHROMA_2X2
     return intra_mode_to_tx_type_context[mbmi->uv_mode];
 #else   // CONFIG_CB4X4
-
   // Sub8x8-Inter/Intra OR UV-Intra
-  if (is_inter_block(mbmi))  // Sub8x8-Inter
+  if (is_inter_block(mbmi)) {  // Sub8x8-Inter
     return DCT_DCT;
-  else  // Sub8x8 Intra OR UV-Intra
+  } else {  // Sub8x8 Intra OR UV-Intra
+    const int block_raster_idx =
+        av1_block_index_to_raster_order(tx_size, block);
     return intra_mode_to_tx_type_context[plane_type == PLANE_TYPE_Y
                                              ? get_y_mode(mi, block_raster_idx)
                                              : mbmi->uv_mode];
+  }
 #endif  // CONFIG_CB4X4
 #else   // CONFIG_EXT_TX
   (void)block;
@@ -1178,18 +1191,6 @@
     return DCT_DCT;
   return mbmi->tx_type;
 #endif  // CONFIG_EXT_TX
-#else   // !CONFIG_TXK_SEL
-  (void)tx_size;
-  TX_TYPE tx_type;
-  if (plane_type != PLANE_TYPE_Y || xd->lossless[mbmi->segment_id] ||
-      txsize_sqr_map[tx_size] >= TX_32X32) {
-    tx_type = DCT_DCT;
-  } else {
-    tx_type = mbmi->txk_type[block];
-  }
-  assert(tx_type >= DCT_DCT && tx_type < TX_TYPES);
-  return tx_type;
-#endif  // !CONFIG_TXK_SEL
 }
 
 void av1_setup_block_planes(MACROBLOCKD *xd, int ss_x, int ss_y);
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 80a671d..ff8ba66 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -2103,7 +2103,7 @@
   tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
   const PLANE_TYPE plane_type = get_plane_type(plane);
   const TX_SIZE tx_size = get_tx_size(plane, xd);
-  const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   const int dst_stride = pd->dst.stride;
   uint8_t *dst =
       &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index dac6e6a..f608d41 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -621,9 +621,9 @@
     av1_read_coeffs_txb_facade(cm, xd, r, row, col, block_idx, plane,
                                pd->dqcoeff, tx_size, &max_scan_line, &eob);
     // tx_type will be read out in av1_read_coeffs_txb_facade
-    TX_TYPE tx_type = get_tx_type(plane_type, xd, block_idx, tx_size);
+    const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
 #else   // CONFIG_LV_MAP
-    TX_TYPE tx_type = get_tx_type(plane_type, xd, block_idx, tx_size);
+    const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
     const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, mbmi);
     int16_t max_scan_line = 0;
     const int eob =
@@ -654,7 +654,7 @@
 #endif  // CONFIG_DPCM_INTRA
     }
 #else   // !CONFIG_PVQ
-    TX_TYPE tx_type = get_tx_type(plane_type, xd, block_idx, tx_size);
+    const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
     av1_pvq_decode_helper2(cm, xd, mbmi, plane, row, col, tx_size, tx_type);
 #endif  // !CONFIG_PVQ
   }
@@ -703,9 +703,11 @@
     av1_read_coeffs_txb_facade(cm, xd, r, blk_row, blk_col, block, plane,
                                pd->dqcoeff, tx_size, &max_scan_line, &eob);
     // tx_type will be read out in av1_read_coeffs_txb_facade
-    TX_TYPE tx_type = get_tx_type(plane_type, xd, block, plane_tx_size);
+    const TX_TYPE tx_type =
+        av1_get_tx_type(plane_type, xd, block, plane_tx_size);
 #else   // CONFIG_LV_MAP
-    TX_TYPE tx_type = get_tx_type(plane_type, xd, block, plane_tx_size);
+    const TX_TYPE tx_type =
+        av1_get_tx_type(plane_type, xd, block, plane_tx_size);
     const SCAN_ORDER *sc = get_scan(cm, plane_tx_size, tx_type, mbmi);
     int16_t max_scan_line = 0;
     const int eob = av1_decode_block_tokens(
@@ -768,10 +770,10 @@
   av1_read_coeffs_txb_facade(cm, xd, r, row, col, block_idx, plane, pd->dqcoeff,
                              tx_size, &max_scan_line, &eob);
   // tx_type will be read out in av1_read_coeffs_txb_facade
-  TX_TYPE tx_type = get_tx_type(plane_type, xd, block_idx, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
 #else   // CONFIG_LV_MAP
   int16_t max_scan_line = 0;
-  TX_TYPE tx_type = get_tx_type(plane_type, xd, block_idx, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
   const SCAN_ORDER *scan_order =
       get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
   const int eob =
@@ -788,7 +790,7 @@
                             tx_type, tx_size, dst, pd->dst.stride,
                             max_scan_line, eob);
 #else
-  TX_TYPE tx_type = get_tx_type(plane_type, xd, block_idx, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block_idx, tx_size);
   eob = av1_pvq_decode_helper2(cm, xd, &xd->mi[0]->mbmi, plane, row, col,
                                tx_size, tx_type);
 #endif
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index d0d79f6..ec87241 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -82,7 +82,7 @@
 #if CONFIG_TXK_SEL
   av1_read_tx_type(cm, xd, block, plane, get_min_tx_size(tx_size), r);
 #endif
-  TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int16_t *scan = scan_order->scan;
 
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 384465a..d1a53d2 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1572,7 +1572,7 @@
   // Only y plane's tx_type is transmitted
   if (plane > 0) return;
   PLANE_TYPE plane_type = get_plane_type(plane);
-  TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
 #endif
 
   if (!FIXED_TX_TYPE) {
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index f3c3098..4870a68 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -5898,7 +5898,7 @@
   (void)blk_col;
   // Only y plane's tx_type is updated
   if (plane > 0) return;
-  TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
 #endif
 #if CONFIG_EXT_TX
   if (get_ext_tx_types(tx_size, bsize, is_inter, cm->reduced_tx_set_used) > 1 &&
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 7e7d390..be46248 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -138,7 +138,7 @@
   const PLANE_TYPE plane_type = pd->plane_type;
   const int16_t *const dequant_ptr = pd->dequant;
   const uint8_t *const band_translate = get_band_translate(tx_size);
-  TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   const SCAN_ORDER *const scan_order =
       get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
   const int16_t *const scan = scan_order->scan;
@@ -508,7 +508,7 @@
   struct macroblockd_plane *const pd = &xd->plane[plane];
 #endif
   PLANE_TYPE plane_type = get_plane_type(plane);
-  TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
 
 #if CONFIG_AOM_QM || CONFIG_NEW_QUANT
   const int is_inter = is_inter_block(mbmi);
@@ -753,7 +753,7 @@
 
   if (x->pvq_skip[plane]) return;
 #endif
-  TX_TYPE tx_type = get_tx_type(pd->plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(pd->plane_type, xd, block, tx_size);
 #if CONFIG_LGT
   PREDICTION_MODE mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
   av1_inverse_transform_block(xd, dqcoeff, mode, tx_type, tx_size, dst,
@@ -1346,7 +1346,7 @@
   struct macroblockd_plane *const pd = &xd->plane[plane];
   tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
   PLANE_TYPE plane_type = get_plane_type(plane);
-  const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   uint16_t *eob = &p->eobs[block];
   const int dst_stride = pd->dst.stride;
   uint8_t *dst =
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 0376871..13e2152 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -78,7 +78,7 @@
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const PLANE_TYPE plane_type = get_plane_type(plane);
   const TX_SIZE txs_ctx = get_txsize_context(tx_size);
-  const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int16_t *scan = scan_order->scan;
   int c;
@@ -295,7 +295,7 @@
   const PLANE_TYPE plane_type = get_plane_type(plane);
   (void)blk_row;
   (void)blk_col;
-  const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const struct macroblock_plane *p = &x->plane[plane];
   const int eob = p->eobs[block];
@@ -1473,7 +1473,7 @@
   const TX_SIZE txs_ctx = get_txsize_context(tx_size);
   (void)blk_row;
   (void)blk_col;
-  const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   const MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const struct macroblock_plane *p = &x->plane[plane];
   struct macroblockd_plane *pd = &xd->plane[plane];
@@ -1548,7 +1548,7 @@
   const uint16_t eob = p->eobs[block];
   const tran_low_t *qcoeff = BLOCK_OFFSET(p->qcoeff, block);
   const PLANE_TYPE plane_type = pd->plane_type;
-  const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   (void)plane_bsize;
 
@@ -1573,7 +1573,7 @@
   const tran_low_t *qcoeff = BLOCK_OFFSET(p->qcoeff, block);
   tran_low_t *tcoeff = BLOCK_OFFSET(x->mbmi_ext->tcoeff[plane], block);
   const int segment_id = mbmi->segment_id;
-  const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int16_t *scan = scan_order->scan;
   const int seg_eob = get_tx_eob(&cpi->common.seg, segment_id, tx_size);
@@ -1883,11 +1883,11 @@
 
   for (tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
     if (plane == 0) mbmi->txk_type[block] = tx_type;
-    TX_TYPE ref_tx_type =
-        get_tx_type(get_plane_type(plane), xd, block, tx_size);
+    const TX_TYPE ref_tx_type =
+        av1_get_tx_type(get_plane_type(plane), xd, block, tx_size);
     if (tx_type != ref_tx_type) {
-      // use get_tx_type() to check if the tx_type is valid for the current mode
-      // if it's not, we skip it here.
+      // use av1_get_tx_type() to check if the tx_type is valid for the current
+      // mode if it's not, we skip it here.
       continue;
     }
 
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index b4903b4..8cb1b97 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1745,7 +1745,7 @@
 #endif  // !CONFIG_PVQ
 
         const PLANE_TYPE plane_type = get_plane_type(plane);
-        TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+        const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
 
         av1_inverse_transform_block(xd, dqcoeff,
 #if CONFIG_LGT
@@ -1821,8 +1821,8 @@
         av1_block_index_to_raster_order(tx_size, block);
     const PREDICTION_MODE mode =
         (plane == 0) ? get_y_mode(xd->mi[0], block_raster_idx) : mbmi->uv_mode;
-    TX_TYPE tx_type = get_tx_type((plane == 0) ? PLANE_TYPE_Y : PLANE_TYPE_UV,
-                                  xd, block, tx_size);
+    const TX_TYPE tx_type = av1_get_tx_type(
+        (plane == 0) ? PLANE_TYPE_Y : PLANE_TYPE_UV, xd, block, tx_size);
     if (av1_use_dpcm_intra(plane, mode, tx_type, mbmi)) {
       int8_t skip;
       av1_encode_block_intra_dpcm(cm, x, mode, plane, block, blk_row, blk_col,
@@ -1882,7 +1882,7 @@
   }
 #if !CONFIG_PVQ
   const PLANE_TYPE plane_type = get_plane_type(plane);
-  const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   this_rd_stats.rate =
       av1_cost_coeffs(cpi, x, plane, blk_row, blk_col, block, tx_size,
@@ -3093,7 +3093,8 @@
                                     src_stride, dst, dst_stride, xd->bd);
 #endif
           if (is_lossless) {
-            TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
+            const TX_TYPE tx_type =
+                av1_get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
             const SCAN_ORDER *scan_order =
                 get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
             const int coeff_ctx =
@@ -3141,7 +3142,8 @@
           } else {
             int64_t dist;
             unsigned int tmp;
-            TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
+            const TX_TYPE tx_type =
+                av1_get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
             const SCAN_ORDER *scan_order =
                 get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
             const int coeff_ctx =
@@ -3300,7 +3302,8 @@
                            dst, dst_stride);
 #endif  // !CONFIG_PVQ
 
-        TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
+        const TX_TYPE tx_type =
+            av1_get_tx_type(PLANE_TYPE_Y, xd, block, tx_size);
         const SCAN_ORDER *scan_order =
             get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
         const int coeff_ctx = combine_entropy_contexts(tempa[idx], templ[idy]);
@@ -4233,7 +4236,7 @@
   int64_t tmp;
   tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
   PLANE_TYPE plane_type = get_plane_type(plane);
-  TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, block, tx_size);
   const SCAN_ORDER *const scan_order =
       get_scan(cm, tx_size, tx_type, &xd->mi[0]->mbmi);
   BLOCK_SIZE txm_bsize = txsize_to_bsize[tx_size];
diff --git a/av1/encoder/tokenize.c b/av1/encoder/tokenize.c
index 5301a95..23d452a 100644
--- a/av1/encoder/tokenize.c
+++ b/av1/encoder/tokenize.c
@@ -277,7 +277,7 @@
   struct macroblock_plane *p = &x->plane[plane];
   struct macroblockd_plane *pd = &xd->plane[plane];
   const PLANE_TYPE type = pd->plane_type;
-  const TX_TYPE tx_type = get_tx_type(type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(type, xd, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int rate = av1_cost_coeffs(
       cpi, x, plane, blk_row, blk_col, block, tx_size, scan_order,
@@ -445,7 +445,7 @@
   const int segment_id = mbmi->segment_id;
 #endif  // CONFIG_SUEPRTX
   const int16_t *scan, *nb;
-  const TX_TYPE tx_type = get_tx_type(type, xd, block, tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(type, xd, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int ref = is_inter_block(mbmi);
   unsigned int(*const counts)[COEFF_CONTEXTS][ENTROPY_TOKENS] =