Refactor av1_get_tx_type()

Change-Id: I1f8332924f5e49d238454bd099a51f9b7a7cdd8c
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 5019ca6..bbee949 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -858,20 +858,19 @@
                                       int blk_col, TX_SIZE tx_size,
                                       int reduced_tx_set) {
   const MB_MODE_INFO *const mbmi = xd->mi[0];
-  const struct macroblockd_plane *const pd = &xd->plane[plane_type];
-  const TxSetType tx_set_type =
-      av1_get_ext_tx_set_type(tx_size, is_inter_block(mbmi), reduced_tx_set);
+  if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32) {
+    return DCT_DCT;
+  }
 
   TX_TYPE tx_type;
-  if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32) {
-    tx_type = DCT_DCT;
+  if (plane_type == PLANE_TYPE_Y) {
+    const int txk_type_idx =
+        av1_get_txk_type_index(mbmi->sb_type, blk_row, blk_col);
+    tx_type = mbmi->txk_type[txk_type_idx];
   } else {
-    if (plane_type == PLANE_TYPE_Y) {
-      const int txk_type_idx =
-          av1_get_txk_type_index(mbmi->sb_type, blk_row, blk_col);
-      tx_type = mbmi->txk_type[txk_type_idx];
-    } else if (is_inter_block(mbmi)) {
+    if (is_inter_block(mbmi)) {
       // scale back to y plane's coordinate
+      const struct macroblockd_plane *const pd = &xd->plane[plane_type];
       blk_row <<= pd->subsampling_y;
       blk_col <<= pd->subsampling_x;
       const int txk_type_idx =
@@ -882,9 +881,13 @@
       // plane, so the tx_type should not be shared
       tx_type = intra_mode_to_tx_type(mbmi, PLANE_TYPE_UV);
     }
+    const TxSetType tx_set_type =
+        av1_get_ext_tx_set_type(tx_size, is_inter_block(mbmi), reduced_tx_set);
+    if (!av1_ext_tx_used[tx_set_type][tx_type]) tx_type = DCT_DCT;
   }
   assert(tx_type < TX_TYPES);
-  if (!av1_ext_tx_used[tx_set_type][tx_type]) return DCT_DCT;
+  assert(av1_ext_tx_used[av1_get_ext_tx_set_type(tx_size, is_inter_block(mbmi),
+                                                 reduced_tx_set)][tx_type]);
   return tx_type;
 }