Misc refactors to support 4:1->2:1->1:1 tx splits

Currently 4:1 transforms have max 2 split levels:
4:1 -> 1:1 -> 0.5:0.5.

This refactor enables split levels:
4:1 -> 2:1 -> 1:1,

by simply changing the tables in common_data.h.

The actual switch will be made in a subsequent patch.

Change-Id: I33f8d9ca5159ba3e7d02ced449ddf6f804a8f12a
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index c9b54ba..e78ade1 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -1105,23 +1105,34 @@
 
 void av1_setup_block_planes(MACROBLOCKD *xd, int ss_x, int ss_y);
 
-static INLINE int bsize_to_max_depth(BLOCK_SIZE bsize) {
-  const int32_t tx_size_cat = intra_tx_size_cat_lookup[bsize];
-  return AOMMIN(tx_size_cat + 1, MAX_TX_DEPTH);
+static INLINE int bsize_to_max_depth(BLOCK_SIZE bsize, int is_inter) {
+  TX_SIZE tx_size = get_max_rect_tx_size(bsize, is_inter);
+  int depth = 0;
+  while (depth < MAX_TX_DEPTH && tx_size != TX_4X4) {
+    depth++;
+    tx_size = sub_tx_size_map[tx_size];
+  }
+  return depth;
 }
 
-static INLINE int tx_size_to_depth(TX_SIZE tx_size, BLOCK_SIZE bsize) {
-  if (tx_size == max_txsize_rect_intra_lookup[bsize]) return 0;
-  const int32_t tx_size_cat = intra_tx_size_cat_lookup[bsize];
-  const TX_SIZE coded_tx_size = txsize_sqr_map[tx_size];
-  return (int)(tx_size_cat + 1 - (int)coded_tx_size);
+static INLINE int tx_size_to_depth(TX_SIZE tx_size, BLOCK_SIZE bsize,
+                                   int is_inter) {
+  TX_SIZE ctx_size = get_max_rect_tx_size(bsize, is_inter);
+  int depth = 0;
+  while (tx_size != ctx_size) {
+    depth++;
+    ctx_size = sub_tx_size_map[ctx_size];
+    assert(depth <= MAX_TX_DEPTH);
+  }
+  return depth;
 }
 
-static INLINE TX_SIZE depth_to_tx_size(int depth, BLOCK_SIZE bsize) {
-  if (depth == 0) return max_txsize_rect_intra_lookup[bsize];
-  const int32_t tx_size_cat = intra_tx_size_cat_lookup[bsize];
-  assert(tx_size_cat + 1 - depth >= 0 && tx_size_cat + 1 - depth < TX_SIZES);
-  return (TX_SIZE)(tx_size_cat + 1 - depth);
+static INLINE TX_SIZE depth_to_tx_size(int depth, BLOCK_SIZE bsize,
+                                       int is_inter) {
+  TX_SIZE max_tx_size = get_max_rect_tx_size(bsize, is_inter);
+  TX_SIZE tx_size = max_tx_size;
+  for (int d = 0; d < depth; ++d) tx_size = sub_tx_size_map[tx_size];
+  return tx_size;
 }
 
 static INLINE TX_SIZE av1_get_uv_tx_size(const MB_MODE_INFO *mbmi,
diff --git a/av1/common/reconintra.c b/av1/common/reconintra.c
index c81ddbd..33c71e8 100644
--- a/av1/common/reconintra.c
+++ b/av1/common/reconintra.c
@@ -2815,7 +2815,6 @@
   // A block should only fail to have a matching transform if it's
   // large and rectangular (such large transform sizes aren't
   // available).
-  assert(block_width >= 32 && block_height >= 32);
   assert((block_width == wpx && block_height == hpx) ||
          (block_width == (wpx >> 1) && block_height == hpx) ||
          (block_width == wpx && block_height == (hpx >> 1)));
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index ad4d4d1..c6fac1b 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -504,7 +504,7 @@
   const BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
   const int32_t tx_size_cat = is_inter ? inter_tx_size_cat_lookup[bsize]
                                        : intra_tx_size_cat_lookup[bsize];
-  const int max_depths = bsize_to_max_depth(bsize);
+  const int max_depths = bsize_to_max_depth(bsize, 0);
   const int ctx = get_tx_size_context(xd);
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
   (void)cm;
@@ -512,7 +512,7 @@
   const int depth = aom_read_symbol(r, ec_ctx->tx_size_cdf[tx_size_cat][ctx],
                                     max_depths + 1, ACCT_STR);
   assert(depth >= 0 && depth <= max_depths);
-  const TX_SIZE tx_size = depth_to_tx_size(depth, bsize);
+  const TX_SIZE tx_size = depth_to_tx_size(depth, bsize, 0);
   return tx_size;
 }
 
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 607357d..12d69ee 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -263,8 +263,8 @@
     const TX_SIZE tx_size = mbmi->tx_size;
     const int tx_size_ctx = get_tx_size_context(xd);
     const int32_t tx_size_cat = intra_tx_size_cat_lookup[bsize];
-    const int depth = tx_size_to_depth(tx_size, bsize);
-    const int max_depths = bsize_to_max_depth(bsize);
+    const int depth = tx_size_to_depth(tx_size, bsize, 0);
+    const int max_depths = bsize_to_max_depth(bsize, 0);
 
     assert(depth >= 0 && depth <= max_depths);
     assert(!is_inter_block(mbmi));
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index bb98ade..1b82279 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -429,11 +429,8 @@
         tx_size_from_tx_mode(mbmi->sb_type, tx_mode, is_inter_block(mbmi));
   } else {
     BLOCK_SIZE bsize = mbmi->sb_type;
-    TX_SIZE max_rect_txsize = get_max_rect_tx_size(bsize, is_inter_block(mbmi));
     TX_SIZE min_tx_size =
-        (TX_SIZE)AOMMAX((int)TX_4X4,
-                        txsize_sqr_map[max_rect_txsize] - MAX_TX_DEPTH +
-                            is_rect_tx(max_rect_txsize));
+        depth_to_tx_size(MAX_TX_DEPTH, bsize, is_inter_block(mbmi));
     mbmi->tx_size = (TX_SIZE)AOMMAX(mbmi->tx_size, min_tx_size);
   }
 }
@@ -4538,8 +4535,8 @@
     const TX_SIZE tx_size = mbmi->tx_size;
     const int tx_size_ctx = get_tx_size_context(xd);
     const int32_t tx_size_cat = intra_tx_size_cat_lookup[bsize];
-    const int depth = tx_size_to_depth(tx_size, bsize);
-    const int max_depths = bsize_to_max_depth(bsize);
+    const int depth = tx_size_to_depth(tx_size, bsize, 0);
+    const int max_depths = bsize_to_max_depth(bsize, 0);
     update_cdf(fc->tx_size_cdf[tx_size_cat][tx_size_ctx], depth,
                max_depths + 1);
   }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 5f51fc7..1859e2d 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2324,7 +2324,7 @@
     const int is_inter = is_inter_block(mbmi);
     const int32_t tx_size_cat = is_inter ? inter_tx_size_cat_lookup[bsize]
                                          : intra_tx_size_cat_lookup[bsize];
-    const int depth = tx_size_to_depth(tx_size, bsize);
+    const int depth = tx_size_to_depth(tx_size, bsize, is_inter);
     const int tx_size_ctx = get_tx_size_context(xd);
     int r_tx_size = x->tx_size_cost[tx_size_cat][tx_size_ctx][depth];
     return r_tx_size;
@@ -2437,7 +2437,6 @@
                             TX_TYPE tx_type, TX_SIZE tx_size, int prune) {
   const MACROBLOCKD *const xd = &x->e_mbd;
   const MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
-  const TX_SIZE max_tx_size = max_txsize_lookup[bs];
   const int is_inter = is_inter_block(mbmi);
 
   if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) return 1;
@@ -2449,7 +2448,6 @@
   if (is_inter && x->use_default_inter_tx_type &&
       tx_type != get_default_tx_type(0, xd, 0, tx_size))
     return 1;
-  if (max_tx_size >= TX_32X32 && tx_size == TX_4X4) return 1;
   const AV1_COMMON *const cm = &cpi->common;
   const TxSetType tx_set_type =
       get_ext_tx_set_type(tx_size, bs, is_inter, cm->reduced_tx_set_used);
@@ -2588,74 +2586,28 @@
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   int64_t rd = INT64_MAX;
   int n;
-  int start_tx, end_tx;
+  int start_tx;
+  int depth;
   int64_t best_rd = INT64_MAX, last_rd = INT64_MAX;
-  const TX_SIZE max_tx_size = max_txsize_lookup[bs];
-  TX_SIZE best_tx_size = max_tx_size;
+  const int is_inter = is_inter_block(mbmi);
+  const TX_SIZE max_rect_tx_size = get_max_rect_tx_size(bs, is_inter);
+  TX_SIZE best_tx_size = max_rect_tx_size;
   TX_TYPE best_tx_type = DCT_DCT;
 #if CONFIG_TXK_SEL
   TX_TYPE best_txk_type[MAX_SB_SQUARE / (TX_SIZE_W_MIN * TX_SIZE_H_MIN)];
 #endif  // CONFIG_TXK_SEL
   const int tx_select = cm->tx_mode == TX_MODE_SELECT;
-  const int is_inter = is_inter_block(mbmi);
 
   av1_invalid_rd_stats(rd_stats);
 
-  int evaluate_rect_tx = 0;
   if (tx_select) {
-    evaluate_rect_tx = is_rect_tx_allowed(xd, mbmi);
-  } else {
-    const TX_SIZE chosen_tx_size =
-        tx_size_from_tx_mode(bs, cm->tx_mode, is_inter);
-    evaluate_rect_tx = is_rect_tx(chosen_tx_size);
-    assert(IMPLIES(evaluate_rect_tx, is_rect_tx_allowed(xd, mbmi)));
-  }
-  if (evaluate_rect_tx) {
-    TX_TYPE tx_start = DCT_DCT;
-    TX_TYPE tx_end = TX_TYPES;
-#if CONFIG_TXK_SEL
-    // The tx_type becomes dummy when lv_map is on. The tx_type search will be
-    // performed in av1_search_txk_type()
-    tx_end = DCT_DCT + 1;
-#endif
-    TX_TYPE tx_type;
-    for (tx_type = tx_start; tx_type < tx_end; ++tx_type) {
-      if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) continue;
-      const TX_SIZE rect_tx_size = get_max_rect_tx_size(bs, is_inter);
-      RD_STATS this_rd_stats;
-      const TxSetType tx_set_type = get_ext_tx_set_type(
-          rect_tx_size, bs, is_inter, cm->reduced_tx_set_used);
-      if (av1_ext_tx_used[tx_set_type][tx_type]) {
-        rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type,
-                      rect_tx_size);
-        ref_best_rd = AOMMIN(rd, ref_best_rd);
-        if (rd < best_rd) {
-#if CONFIG_TXK_SEL
-          memcpy(best_txk_type, mbmi->txk_type,
-                 sizeof(best_txk_type[0]) * MAX_SB_SQUARE /
-                     (TX_SIZE_W_MIN * TX_SIZE_H_MIN));
-#endif
-          best_tx_type = tx_type;
-          best_tx_size = rect_tx_size;
-          best_rd = rd;
-          *rd_stats = this_rd_stats;
-        }
-      }
-#if !USE_TXTYPE_SEARCH_FOR_SUB8X8_IN_CB4X4
-      const int is_inter = is_inter_block(mbmi);
-      if (mbmi->sb_type < BLOCK_8X8 && is_inter) break;
-#endif  // !USE_TXTYPE_SEARCH_FOR_SUB8X8_IN_CB4X4
-    }
-  }
-
-  if (tx_select) {
-    start_tx = max_tx_size;
-    end_tx = AOMMAX((int)TX_4X4, start_tx - MAX_TX_DEPTH + evaluate_rect_tx);
+    start_tx = max_rect_tx_size;
+    depth = 0;
   } else {
     const TX_SIZE chosen_tx_size =
         tx_size_from_tx_mode(bs, cm->tx_mode, is_inter);
     start_tx = chosen_tx_size;
-    end_tx = chosen_tx_size;
+    depth = MAX_TX_DEPTH;
   }
 
   int prune = 0;
@@ -2665,8 +2617,7 @@
   }
 
   last_rd = INT64_MAX;
-  for (n = start_tx; n >= end_tx; --n) {
-    if (is_rect_tx(n)) break;
+  for (n = start_tx; depth <= MAX_TX_DEPTH; depth++, n = sub_tx_size_map[n]) {
     TX_TYPE tx_start = DCT_DCT;
     TX_TYPE tx_end = TX_TYPES;
 #if CONFIG_TXK_SEL
@@ -2683,8 +2634,8 @@
       // Early termination in transform size search.
       if (cpi->sf.tx_size_search_breakout &&
           (rd == INT64_MAX ||
-           (this_rd_stats.skip == 1 && tx_type != DCT_DCT && n < start_tx) ||
-           (n < (int)max_tx_size && rd > last_rd))) {
+           (this_rd_stats.skip == 1 && tx_type != DCT_DCT && n != start_tx) ||
+           (n != (int)start_tx && rd > last_rd))) {
         break;
       }