Enable rectangular transforms for intra blocks as well.

They are enabled only when both the EXT_TX and RECT_TX experiments are on.

Results
=======

Derf set
--------
All Intra frames: 1.8% avg improvement (and 1.78% BD-rate improvement)
Video: 0.230% avg improvement (and 0.262% BD-rate improvement)

Objective-1-fast set
--------------------
Video: 0.52 PSNR improvement

Change-Id: I1893465929858e38419f327752dc61c19b96b997
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 7c59417..84de6aa 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -384,45 +384,32 @@
                                      int tx_size_cat, aom_reader *r) {
   FRAME_COUNTS *counts = xd->counts;
   const int ctx = get_tx_size_context(xd);
-  int depth = aom_read_tree(r, av1_tx_size_tree[tx_size_cat],
-                            cm->fc->tx_size_probs[tx_size_cat][ctx], ACCT_STR);
-  TX_SIZE tx_size = depth_to_tx_size(depth);
+  const int depth =
+      aom_read_tree(r, av1_tx_size_tree[tx_size_cat],
+                    cm->fc->tx_size_probs[tx_size_cat][ctx], ACCT_STR);
+  const TX_SIZE tx_size = depth_to_tx_size(depth);
+#if CONFIG_RECT_TX
+  assert(!is_rect_tx(tx_size));
+#endif  // CONFIG_RECT_TX
   if (counts) ++counts->tx_size[tx_size_cat][ctx][depth];
   return tx_size;
 }
 
-static TX_SIZE read_tx_size_intra(AV1_COMMON *cm, MACROBLOCKD *xd,
-                                  aom_reader *r) {
-  TX_MODE tx_mode = cm->tx_mode;
-  BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
-  if (xd->lossless[xd->mi[0]->mbmi.segment_id]) return TX_4X4;
-  if (bsize >= BLOCK_8X8) {
-    if (tx_mode == TX_MODE_SELECT) {
-      const TX_SIZE tx_size =
-          read_selected_tx_size(cm, xd, intra_tx_size_cat_lookup[bsize], r);
-      assert(tx_size <= max_txsize_lookup[bsize]);
-      return tx_size;
-    } else {
-      return tx_size_from_tx_mode(bsize, cm->tx_mode, 0);
-    }
-  } else {
-    return TX_4X4;
-  }
-}
-
-static TX_SIZE read_tx_size_inter(AV1_COMMON *cm, MACROBLOCKD *xd,
-                                  int allow_select, aom_reader *r) {
-  TX_MODE tx_mode = cm->tx_mode;
-  BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
+static TX_SIZE read_tx_size(AV1_COMMON *cm, MACROBLOCKD *xd, int is_inter,
+                            int allow_select_inter, aom_reader *r) {
+  const TX_MODE tx_mode = cm->tx_mode;
+  const BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
   if (xd->lossless[xd->mi[0]->mbmi.segment_id]) return TX_4X4;
 #if CONFIG_CB4X4 && CONFIG_VAR_TX
   if (bsize > BLOCK_4X4) {
 #else
   if (bsize >= BLOCK_8X8) {
-#endif
-    if (allow_select && tx_mode == TX_MODE_SELECT) {
+#endif  // CONFIG_CB4X4 && CONFIG_VAR_TX
+    if ((!is_inter || allow_select_inter) && tx_mode == TX_MODE_SELECT) {
+      const int32_t tx_size_cat = is_inter ? inter_tx_size_cat_lookup[bsize]
+                                           : intra_tx_size_cat_lookup[bsize];
       const TX_SIZE coded_tx_size =
-          read_selected_tx_size(cm, xd, inter_tx_size_cat_lookup[bsize], r);
+          read_selected_tx_size(cm, xd, tx_size_cat, r);
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
       if (coded_tx_size > max_txsize_lookup[bsize]) {
         assert(coded_tx_size == max_txsize_lookup[bsize] + 1);
@@ -433,7 +420,7 @@
 #endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
       return coded_tx_size;
     } else {
-      return tx_size_from_tx_mode(bsize, cm->tx_mode, 1);
+      return tx_size_from_tx_mode(bsize, tx_mode, is_inter);
     }
   } else {
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
@@ -441,7 +428,7 @@
     return max_txsize_rect_lookup[bsize];
 #else
     return TX_4X4;
-#endif
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
   }
 }
 
@@ -711,6 +698,7 @@
 #endif
   if (!FIXED_TX_TYPE) {
 #if CONFIG_EXT_TX
+    const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
     if (get_ext_tx_types(tx_size, mbmi->sb_type, inter_block) > 1 &&
         cm->base_qindex > 0 && !mbmi->skip &&
 #if CONFIG_SUPERTX
@@ -724,19 +712,19 @@
         if (eset > 0) {
           mbmi->tx_type = aom_read_tree(
               r, av1_ext_tx_inter_tree[eset],
-              cm->fc->inter_ext_tx_prob[eset][txsize_sqr_map[tx_size]],
-              ACCT_STR);
+              cm->fc->inter_ext_tx_prob[eset][square_tx_size], ACCT_STR);
           if (counts)
-            ++counts->inter_ext_tx[eset][txsize_sqr_map[tx_size]]
-                                  [mbmi->tx_type];
+            ++counts->inter_ext_tx[eset][square_tx_size][mbmi->tx_type];
         }
       } else if (ALLOW_INTRA_EXT_TX) {
         if (eset > 0) {
           mbmi->tx_type = aom_read_tree(
               r, av1_ext_tx_intra_tree[eset],
-              cm->fc->intra_ext_tx_prob[eset][tx_size][mbmi->mode], ACCT_STR);
+              cm->fc->intra_ext_tx_prob[eset][square_tx_size][mbmi->mode],
+              ACCT_STR);
           if (counts)
-            ++counts->intra_ext_tx[eset][tx_size][mbmi->mode][mbmi->tx_type];
+            ++counts->intra_ext_tx[eset][square_tx_size][mbmi->mode]
+                                  [mbmi->tx_type];
         }
       }
     } else {
@@ -807,7 +795,7 @@
   }
 #endif
 
-  mbmi->tx_size = read_tx_size_intra(cm, xd, r);
+  mbmi->tx_size = read_tx_size(cm, xd, 0, 1, r);
   mbmi->ref_frame[0] = INTRA_FRAME;
   mbmi->ref_frame[1] = NONE;
 
@@ -1967,10 +1955,7 @@
           read_tx_size_vartx(cm, xd, mbmi, xd->counts, max_tx_size,
                              height != width, idy, idx, r);
     } else {
-      if (inter_block)
-        mbmi->tx_size = read_tx_size_inter(cm, xd, !mbmi->skip, r);
-      else
-        mbmi->tx_size = read_tx_size_intra(cm, xd, r);
+      mbmi->tx_size = read_tx_size(cm, xd, inter_block, !mbmi->skip, r);
 
       if (inter_block) {
         const int width = block_size_wide[bsize] >> tx_size_wide_log2[0];
@@ -1984,10 +1969,7 @@
       set_txfm_ctxs(mbmi->tx_size, xd->n8_w, xd->n8_h, mbmi->skip, xd);
     }
 #else
-  if (inter_block)
-    mbmi->tx_size = read_tx_size_inter(cm, xd, !mbmi->skip, r);
-  else
-    mbmi->tx_size = read_tx_size_intra(cm, xd, r);
+  mbmi->tx_size = read_tx_size(cm, xd, inter_block, !mbmi->skip, r);
 #endif  // CONFIG_VAR_TX
 #if CONFIG_SUPERTX
   }