Fixes to make 4:1 rectangular intra work correctly

This patch fixes and enables rectangular intra transform
sizes for 4:1 partitions, which were previously turned off.
4:1 partitions can now use rectangular intra prediction with
2:1 rectangular transform sizes.
BDRATE lowres (single keyframe): -0.612%

Change-Id: I6f062f7c08aae8eeb0a55d31e792c8f7e3f302a2
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 384b59e..eb434eb 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -900,15 +900,15 @@
     0,  // BLOCK_128X64
     0,  // BLOCK_128X128
 #endif  // CONFIG_EXT_PARTITION
-    0,  // BLOCK_4X16
-    0,  // BLOCK_16X4
-    0,  // BLOCK_8X32
-    0,  // BLOCK_32X8
-    0,  // BLOCK_16X64
-    0,  // BLOCK_64X16
+    1,  // BLOCK_4X16
+    1,  // BLOCK_16X4
+    1,  // BLOCK_8X32
+    1,  // BLOCK_32X8
+    1,  // BLOCK_16X64
+    1,  // BLOCK_64X16
 #if CONFIG_EXT_PARTITION
-    0,  // BLOCK_32X128
-    0,  // BLOCK_128X32
+    1,  // BLOCK_32X128
+    1,  // BLOCK_128X32
 #endif  // CONFIG_EXT_PARTITION
   };
 
@@ -1075,15 +1075,21 @@
 
 void av1_setup_block_planes(MACROBLOCKD *xd, int ss_x, int ss_y);
 
-static INLINE int tx_size_cat_to_max_depth(int tx_size_cat) {
+static INLINE int bsize_to_max_depth(BLOCK_SIZE bsize) {
+  const int32_t tx_size_cat = intra_tx_size_cat_lookup[bsize];
   return AOMMIN(tx_size_cat + 1, MAX_TX_DEPTH);
 }
 
-static INLINE int tx_size_to_depth(TX_SIZE tx_size, int tx_size_cat) {
-  return (int)(tx_size_cat + 1 - (int)tx_size);
+static INLINE int tx_size_to_depth(TX_SIZE tx_size, BLOCK_SIZE bsize) {
+  if (tx_size == max_txsize_rect_intra_lookup[bsize]) return 0;
+  const int32_t tx_size_cat = intra_tx_size_cat_lookup[bsize];
+  const TX_SIZE coded_tx_size = txsize_sqr_map[tx_size];
+  return (int)(tx_size_cat + 1 - (int)coded_tx_size);
 }
 
-static INLINE TX_SIZE depth_to_tx_size(int depth, int tx_size_cat) {
+static INLINE TX_SIZE depth_to_tx_size(int depth, BLOCK_SIZE bsize) {
+  if (depth == 0) return max_txsize_rect_intra_lookup[bsize];
+  const int32_t tx_size_cat = intra_tx_size_cat_lookup[bsize];
   assert(tx_size_cat + 1 - depth >= 0 && tx_size_cat + 1 - depth < TX_SIZES);
   return (TX_SIZE)(tx_size_cat + 1 - depth);
 }
diff --git a/av1/common/common_data.h b/av1/common/common_data.h
index 42192fa..3a1c2a6 100644
--- a/av1/common/common_data.h
+++ b/av1/common/common_data.h
@@ -803,12 +803,17 @@
   // TODO(david.barker): Change these if we support rectangular transforms
   // for 4:1 shaped partitions
   // 4x16,            16x4,               8x32
-  TX_8X8 - TX_8X8,    TX_8X8 - TX_8X8,    TX_8X8 - TX_8X8,
+  TX_8X8 - TX_8X8,    TX_8X8 - TX_8X8,    TX_16X16 - TX_8X8,
   // 32x8,            16x64,              64x16
-  TX_8X8 - TX_8X8,    TX_16X16 - TX_8X8,  TX_16X16 - TX_8X8,
+  TX_16X16 - TX_8X8,  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,
 #if CONFIG_EXT_PARTITION
+#if CONFIG_TX64X64
+  // 32x128,          128x32
+  TX_64X64 - TX_8X8,  TX_64X64 - TX_8X8
+#else
   // 32x128,          128x32
   TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8
+#endif  // CONFIG_TX64X64
 #endif  // CONFIG_EXT_PARTITION
 };
 
diff --git a/av1/common/scan.c b/av1/common/scan.c
index 0de19ac..3250a4d 100644
--- a/av1/common/scan.c
+++ b/av1/common/scan.c
@@ -5396,10 +5396,8 @@
       // TX_4X16
       { default_scan_4x16, av1_default_iscan_4x16,
         default_scan_4x16_neighbors },
-      { default_scan_4x16, av1_default_iscan_4x16,
-        default_scan_4x16_neighbors },
-      { default_scan_4x16, av1_default_iscan_4x16,
-        default_scan_4x16_neighbors },
+      { mrow_scan_4x16, av1_mrow_iscan_4x16, mrow_scan_4x16_neighbors },
+      { mcol_scan_4x16, av1_mcol_iscan_4x16, mcol_scan_4x16_neighbors },
       { default_scan_4x16, av1_default_iscan_4x16,
         default_scan_4x16_neighbors },
       { default_scan_4x16, av1_default_iscan_4x16,
@@ -5424,10 +5422,8 @@
       // TX_16X4
       { default_scan_16x4, av1_default_iscan_16x4,
         default_scan_16x4_neighbors },
-      { default_scan_16x4, av1_default_iscan_16x4,
-        default_scan_16x4_neighbors },
-      { default_scan_16x4, av1_default_iscan_16x4,
-        default_scan_16x4_neighbors },
+      { mrow_scan_16x4, av1_mrow_iscan_16x4, mrow_scan_16x4_neighbors },
+      { mcol_scan_16x4, av1_mcol_iscan_16x4, mcol_scan_16x4_neighbors },
       { default_scan_16x4, av1_default_iscan_16x4,
         default_scan_16x4_neighbors },
       { default_scan_16x4, av1_default_iscan_16x4,
@@ -5452,10 +5448,8 @@
       // TX_8X32
       { default_scan_8x32, av1_default_iscan_8x32,
         default_scan_8x32_neighbors },
-      { default_scan_8x32, av1_default_iscan_8x32,
-        default_scan_8x32_neighbors },
-      { default_scan_8x32, av1_default_iscan_8x32,
-        default_scan_8x32_neighbors },
+      { mrow_scan_8x32, av1_mrow_iscan_8x32, mrow_scan_8x32_neighbors },
+      { mcol_scan_8x32, av1_mcol_iscan_8x32, mcol_scan_8x32_neighbors },
       { default_scan_8x32, av1_default_iscan_8x32,
         default_scan_8x32_neighbors },
       { default_scan_8x32, av1_default_iscan_8x32,
@@ -5480,10 +5474,8 @@
       // TX_32X8
       { default_scan_32x8, av1_default_iscan_32x8,
         default_scan_32x8_neighbors },
-      { default_scan_32x8, av1_default_iscan_32x8,
-        default_scan_32x8_neighbors },
-      { default_scan_32x8, av1_default_iscan_32x8,
-        default_scan_32x8_neighbors },
+      { mrow_scan_32x8, av1_mrow_iscan_32x8, mrow_scan_32x8_neighbors },
+      { mcol_scan_32x8, av1_mcol_iscan_32x8, mcol_scan_32x8_neighbors },
       { default_scan_32x8, av1_default_iscan_32x8,
         default_scan_32x8_neighbors },
       { default_scan_32x8, av1_default_iscan_32x8,
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 085ddd9..51fad32 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -498,8 +498,13 @@
 }
 
 static TX_SIZE read_selected_tx_size(AV1_COMMON *cm, MACROBLOCKD *xd,
-                                     int32_t tx_size_cat, aom_reader *r) {
-  const int max_depths = tx_size_cat_to_max_depth(tx_size_cat);
+                                     int is_inter, aom_reader *r) {
+  // TODO(debargha): Clean up the logic here. This function should only
+  // be called for intra.
+  const BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
+  const int32_t tx_size_cat = is_inter ? inter_tx_size_cat_lookup[bsize]
+                                       : intra_tx_size_cat_lookup[bsize];
+  const int max_depths = bsize_to_max_depth(bsize);
   const int ctx = get_tx_size_context(xd);
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
   (void)cm;
@@ -507,8 +512,7 @@
   const int depth = aom_read_symbol(r, ec_ctx->tx_size_cdf[tx_size_cat][ctx],
                                     max_depths + 1, ACCT_STR);
   assert(depth >= 0 && depth <= max_depths);
-  const TX_SIZE tx_size = depth_to_tx_size(depth, tx_size_cat);
-  assert(!is_rect_tx(tx_size));
+  const TX_SIZE tx_size = depth_to_tx_size(depth, bsize);
   return tx_size;
 }
 
@@ -520,14 +524,7 @@
 
   if (block_signals_txsize(bsize)) {
     if ((!is_inter || allow_select_inter) && tx_mode == TX_MODE_SELECT) {
-      const int32_t tx_size_cat = is_inter ? inter_tx_size_cat_lookup[bsize]
-                                           : intra_tx_size_cat_lookup[bsize];
-      const TX_SIZE coded_tx_size =
-          read_selected_tx_size(cm, xd, tx_size_cat, r);
-      if (coded_tx_size > max_txsize_lookup[bsize]) {
-        assert(coded_tx_size == max_txsize_lookup[bsize] + 1);
-        return get_max_rect_tx_size(bsize, is_inter);
-      }
+      const TX_SIZE coded_tx_size = read_selected_tx_size(cm, xd, is_inter, r);
       return coded_tx_size;
     } else {
       return tx_size_from_tx_mode(bsize, tx_mode, is_inter);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 3b50a6b..a86cd96 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -257,13 +257,10 @@
     const TX_SIZE tx_size = mbmi->tx_size;
     const int tx_size_ctx = get_tx_size_context(xd);
     const int32_t tx_size_cat = intra_tx_size_cat_lookup[bsize];
-    const TX_SIZE coded_tx_size = txsize_sqr_up_map[tx_size];
-    const int depth = tx_size_to_depth(coded_tx_size, tx_size_cat);
-    const int max_depths = tx_size_cat_to_max_depth(tx_size_cat);
+    const int depth = tx_size_to_depth(tx_size, bsize);
+    const int max_depths = bsize_to_max_depth(bsize);
 
-    assert(coded_tx_size <= tx_size_cat + 1);
     assert(depth >= 0 && depth <= max_depths);
-
     assert(!is_inter_block(mbmi));
     assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed(xd, mbmi)));
 
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index c476a92..a22217f 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4506,9 +4506,8 @@
     const TX_SIZE tx_size = mbmi->tx_size;
     const int tx_size_ctx = get_tx_size_context(xd);
     const int32_t tx_size_cat = intra_tx_size_cat_lookup[bsize];
-    const TX_SIZE coded_tx_size = txsize_sqr_up_map[tx_size];
-    const int depth = tx_size_to_depth(coded_tx_size, tx_size_cat);
-    const int max_depths = tx_size_cat_to_max_depth(tx_size_cat);
+    const int depth = tx_size_to_depth(tx_size, bsize);
+    const int max_depths = bsize_to_max_depth(bsize);
     update_cdf(fc->tx_size_cdf[tx_size_cat][tx_size_ctx], depth,
                max_depths + 1);
   }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 58b789f..a5677c7 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2153,6 +2153,7 @@
   const PLANE_TYPE plane_type = get_plane_type(plane);
   const TX_TYPE tx_type =
       av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
+
   const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   this_rd_stats.rate =
       av1_cost_coeffs(cpi, x, plane, blk_row, blk_col, block, tx_size,
@@ -2326,8 +2327,7 @@
     const int is_inter = is_inter_block(mbmi);
     const int32_t tx_size_cat = is_inter ? inter_tx_size_cat_lookup[bsize]
                                          : intra_tx_size_cat_lookup[bsize];
-    const TX_SIZE coded_tx_size = txsize_sqr_up_map[tx_size];
-    const int depth = tx_size_to_depth(coded_tx_size, tx_size_cat);
+    const int depth = tx_size_to_depth(tx_size, bsize);
     const int tx_size_ctx = get_tx_size_context(xd);
     int r_tx_size = x->tx_size_cost[tx_size_cat][tx_size_ctx][depth];
     return r_tx_size;
@@ -3627,7 +3627,8 @@
                       tx_size, a, l, 0, rd_stats);
   return;
 #endif
-
+  // This function is used only for inter
+  assert(is_inter_block(&xd->mi[0]->mbmi));
   int64_t tmp;
   tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
   PLANE_TYPE plane_type = get_plane_type(plane);