Refactor search_txk_type

1. Reduce the loops of uv_plane since tx_type should be the same as y_plane
2. Don't check ref_tx_type if allowed_tx_mask is already 0

Change-Id: I495d5db58fedcf72a888f75e70b8bd7530a5a194
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 9a2788b..84b4bc2 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1979,6 +1979,14 @@
 
   int allowed_tx_mask[TX_TYPES] = { 0 };  // 1: allow; 0: skip.
   int allowed_tx_num = 0;
+  TX_TYPE uv_tx_type = DCT_DCT;
+  if (plane) {
+    // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
+    uv_tx_type = txk_start = txk_end =
+        av1_get_tx_type(get_plane_type(plane), xd, blk_row, blk_col, tx_size,
+                        cm->reduced_tx_set_used);
+  }
+
   for (TX_TYPE tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
     allowed_tx_mask[tx_type] = 1;
     if (do_prune) {
@@ -1995,20 +2003,26 @@
       if (is_inter && x->use_default_inter_tx_type &&
           tx_type != get_default_tx_type(0, xd, tx_size))
         allowed_tx_mask[tx_type] = 0;
-      mbmi->txk_type[txk_type_idx] = tx_type;
+
+      if (allowed_tx_mask[tx_type]) {
+        mbmi->txk_type[txk_type_idx] = tx_type;
+        const TX_TYPE ref_tx_type =
+            av1_get_tx_type(get_plane_type(plane), xd, blk_row, blk_col,
+                            tx_size, cm->reduced_tx_set_used);
+        if (tx_type != ref_tx_type) {
+          // use av1_get_tx_type() to check if the tx_type is valid for the
+          // current mode if it's not, we skip it here.
+          allowed_tx_mask[tx_type] = 0;
+        }
+      }
     }
-    const TX_TYPE ref_tx_type =
-        av1_get_tx_type(get_plane_type(plane), xd, blk_row, blk_col, tx_size,
-                        cm->reduced_tx_set_used);
-    if (tx_type != ref_tx_type) {
-      // use av1_get_tx_type() to check if the tx_type is valid for the current
-      // mode if it's not, we skip it here.
-      allowed_tx_mask[tx_type] = 0;
-    }
+
     allowed_tx_num += allowed_tx_mask[tx_type];
   }
   // Need to have at least one transform type allowed.
-  if (allowed_tx_num == 0) allowed_tx_mask[DCT_DCT] = 1;
+  if (allowed_tx_num == 0) {
+    allowed_tx_mask[plane ? uv_tx_type : DCT_DCT] = 1;
+  }
 
   int use_transform_domain_distortion =
       cpi->sf.use_transform_domain_distortion &&