Various RD fixes related to 4:1 transforms

The fixes in rdopt.c improve the coding performance of
4:1 transforms significantly.

Change-Id: I0e8db93e3f6d9bf0b2de01f2ce83c305d78d2262
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 8b17582..1b23e77 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -871,10 +871,10 @@
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const int inter_block = is_inter_block(mbmi);
 #if !CONFIG_TXK_SEL
-  const TX_SIZE sqr_up_tx_size =
-      txsize_sqr_up_map[max_txsize_rect_lookup[xd->mi[0]->mbmi.sb_type]];
+  const TX_SIZE mtx_size =
+      get_max_rect_tx_size(xd->mi[0]->mbmi.sb_type, inter_block);
   const TX_SIZE tx_size =
-      inter_block ? AOMMAX(sub_tx_size_map[sqr_up_tx_size], mbmi->min_tx_size)
+      inter_block ? AOMMAX(sub_tx_size_map[mtx_size], mbmi->min_tx_size)
                   : mbmi->tx_size;
 #endif  // !CONFIG_TXK_SEL
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index bf552a4..7d26aa4 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1095,10 +1095,10 @@
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const int is_inter = is_inter_block(mbmi);
 #if !CONFIG_TXK_SEL
-  const TX_SIZE sqr_up_tx_size =
-      txsize_sqr_up_map[max_txsize_rect_lookup[xd->mi[0]->mbmi.sb_type]];
+  const TX_SIZE mtx_size =
+      get_max_rect_tx_size(xd->mi[0]->mbmi.sb_type, is_inter);
   const TX_SIZE tx_size =
-      is_inter ? AOMMAX(sub_tx_size_map[sqr_up_tx_size], mbmi->min_tx_size)
+      is_inter ? AOMMAX(sub_tx_size_map[mtx_size], mbmi->min_tx_size)
                : mbmi->tx_size;
 #endif  // !CONFIG_TXK_SEL
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 287c521..b5fd600 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4381,31 +4381,34 @@
                           tx_size);
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
+    const int step = bsh * bsw;
     RD_STATS pn_rd_stats;
     int64_t this_rd = 0;
-    assert(bsl > 0);
+    assert(bsw > 0 && bsh > 0);
 
-    for (i = 0; i < 4; ++i) {
-      int offsetr = blk_row + (i >> 1) * bsl;
-      int offsetc = blk_col + (i & 0x01) * bsl;
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        const int offsetr = blk_row + row;
+        const int offsetc = blk_col + col;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      av1_init_rd_stats(&pn_rd_stats);
-      tx_block_yrd(cpi, x, offsetr, offsetc, plane, block, sub_txs, plane_bsize,
-                   depth + 1, above_ctx, left_ctx, tx_above, tx_left,
-                   ref_best_rd - this_rd, &pn_rd_stats, fast);
-      if (pn_rd_stats.rate == INT_MAX) {
-        av1_invalid_rd_stats(rd_stats);
-        return;
+        av1_init_rd_stats(&pn_rd_stats);
+        tx_block_yrd(cpi, x, offsetr, offsetc, plane, block, sub_txs,
+                     plane_bsize, depth + 1, above_ctx, left_ctx, tx_above,
+                     tx_left, ref_best_rd - this_rd, &pn_rd_stats, fast);
+        if (pn_rd_stats.rate == INT_MAX) {
+          av1_invalid_rd_stats(rd_stats);
+          return;
+        }
+        av1_merge_rd_stats(rd_stats, &pn_rd_stats);
+        this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
+        block += step;
       }
-      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
-      this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
-      block += step;
     }
+
     if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
       rd_stats->rate += x->txfm_partition_cost[ctx][1];
   }
@@ -4771,7 +4774,7 @@
   TX_TYPE tx_type, best_tx_type = DCT_DCT;
   const int is_inter = is_inter_block(mbmi);
   TX_SIZE best_tx_size[MAX_MIB_SIZE][MAX_MIB_SIZE];
-  TX_SIZE best_tx = max_txsize_lookup[bsize];
+  TX_SIZE best_tx = max_txsize_rect_lookup[bsize];
   TX_SIZE best_min_tx_size = TX_SIZES_ALL;
   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
   TX_TYPE txk_start = DCT_DCT;
@@ -4783,10 +4786,8 @@
   const int n4 = bsize_to_num_blk(bsize);
   int idx, idy;
   int prune = 0;
-  const TX_SIZE sqr_up_tx_size =
-      txsize_sqr_up_map[max_txsize_rect_lookup[bsize]];
   // Get the tx_size 1 level down
-  TX_SIZE min_tx_size = sub_tx_size_map[sqr_up_tx_size];
+  TX_SIZE min_tx_size = sub_tx_size_map[max_txsize_rect_lookup[bsize]];
   const TxSetType tx_set_type = get_ext_tx_set_type(
       min_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
   int within_border = (mi_row + mi_size_high[bsize] <= cm->mi_rows) &&
@@ -4967,21 +4968,21 @@
     av1_set_txb_context(x, plane, block, tx_size, ta, tl);
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
-    int i;
-
-    assert(bsl > 0);
-
-    for (i = 0; i < 4; ++i) {
-      int offsetr = blk_row + (i >> 1) * bsl;
-      int offsetc = blk_col + (i & 0x01) * bsl;
-
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
-
-      tx_block_rd(cpi, x, offsetr, offsetc, plane, block, sub_txs, plane_bsize,
-                  above_ctx, left_ctx, rd_stats, fast);
-      block += step;
+    assert(IMPLIES(tx_size <= TX_4X4, sub_txs == tx_size));
+    assert(IMPLIES(tx_size > TX_4X4, sub_txs < tx_size));
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
+    const int step = bsh * bsw;
+    assert(bsw > 0 && bsh > 0);
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        const int offsetr = blk_row + row;
+        const int offsetc = blk_col + col;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        tx_block_rd(cpi, x, offsetr, offsetc, plane, block, sub_txs,
+                    plane_bsize, above_ctx, left_ctx, rd_stats, fast);
+        block += step;
+      }
     }
   }
 }