Add the option of using 1:4/4:1 tx_size+sb_type

Change-Id: I96e5ff72caee8935efb7535afa3a534175bc425c
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 550afde..8656b28 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -279,8 +279,8 @@
     // TODO(yuec): set correct txfm partition update for qttx
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
 
 #if CONFIG_NEW_MULTISYMBOL
     aom_write_symbol(w, 1, ec_ctx->txfm_partition_cdf[ctx], 2);
@@ -294,13 +294,14 @@
       return;
     }
 
-    assert(bsl > 0);
-    for (i = 0; i < 4; ++i) {
-      int offsetr = blk_row + (i >> 1) * bsl;
-      int offsetc = blk_col + (i & 0x01) * bsl;
-      write_tx_size_vartx(cm, xd, mbmi, sub_txs, depth + 1, offsetr, offsetc,
-                          w);
-    }
+    assert(bsw > 0 && bsh > 0);
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh)
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        int offsetr = blk_row + row;
+        int offsetc = blk_col + col;
+        write_tx_size_vartx(cm, xd, mbmi, sub_txs, depth + 1, offsetr, offsetc,
+                            w);
+      }
   }
 }
 
@@ -654,21 +655,24 @@
 #endif
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
 
-    assert(bsl > 0);
+    assert(bsw > 0 && bsh > 0);
 
-    for (i = 0; i < 4; ++i) {
-      const int offsetr = blk_row + (i >> 1) * bsl;
-      const int offsetc = blk_col + (i & 0x01) * bsl;
-      const int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
+    for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
+      for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw) {
+        const int offsetr = blk_row + r;
+        const int offsetc = blk_col + c;
+        const int step = bsh * bsw;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      pack_txb_tokens(w, cm, x, tp, tok_end, xd, mbmi, plane, plane_bsize,
-                      bit_depth, block, offsetr, offsetc, sub_txs, token_stats);
-      block += step;
+        pack_txb_tokens(w, cm, x, tp, tok_end, xd, mbmi, plane, plane_bsize,
+                        bit_depth, block, offsetr, offsetc, sub_txs,
+                        token_stats);
+        block += step;
+      }
     }
   }
 }
@@ -716,31 +720,23 @@
 #else
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
 #endif
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
 
-    assert(bsl > 0);
+    assert(bsw > 0 && bsh > 0);
 
-    for (i = 0; i < 4; ++i) {
-#if CONFIG_RECT_TX_EXT
-      int is_wide_tx = tx_size_wide_unit[sub_txs] > tx_size_high_unit[sub_txs];
-      const int offsetr =
-          is_qttx ? (is_wide_tx ? i * tx_size_high_unit[sub_txs] : 0)
-                  : blk_row + (i >> 1) * bsl;
-      const int offsetc =
-          is_qttx ? (is_wide_tx ? 0 : i * tx_size_wide_unit[sub_txs])
-                  : blk_col + (i & 0x01) * bsl;
-#else
-      const int offsetr = blk_row + (i >> 1) * bsl;
-      const int offsetc = blk_col + (i & 0x01) * bsl;
-#endif
-      const int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
+    for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
+      for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw) {
+        const int offsetr = blk_row + r;
+        const int offsetc = blk_col + c;
+        const int step = bsh * bsw;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      pack_txb_tokens(w, tp, tok_end, xd, mbmi, plane, plane_bsize, bit_depth,
-                      block, offsetr, offsetc, sub_txs, token_stats);
-      block += step;
+        pack_txb_tokens(w, tp, tok_end, xd, mbmi, plane, plane_bsize, bit_depth,
+                        block, offsetr, offsetc, sub_txs, token_stats);
+        block += step;
+      }
     }
   }
 }
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 8e6aac2..f0619b9 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4626,8 +4626,8 @@
                           xd->left_txfm_context + blk_row, tx_size, tx_size);
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bs = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
 
     ++counts->txfm_partition[ctx][1];
 #if CONFIG_NEW_MULTISYMBOL
@@ -4644,11 +4644,14 @@
       return;
     }
 
-    for (i = 0; i < 4; ++i) {
-      int offsetr = (i >> 1) * bs;
-      int offsetc = (i & 0x01) * bs;
-      update_txfm_count(x, xd, counts, sub_txs, depth + 1, blk_row + offsetr,
-                        blk_col + offsetc, allow_update_cdf);
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        int offsetr = row;
+        int offsetc = col;
+
+        update_txfm_count(x, xd, counts, sub_txs, depth + 1, blk_row + offsetr,
+                          blk_col + offsetc, allow_update_cdf);
+      }
     }
   }
 }
@@ -4936,14 +4939,15 @@
         tx_partition_count_update(cm, x, bsize, mi_row, mi_col, td->counts,
                                   tile_data->allow_update_cdf);
       } else {
-        if (tx_size != max_txsize_rect_lookup[bsize]) ++x->txb_split_count;
+        if (tx_size != get_max_rect_tx_size(bsize, 0)) ++x->txb_split_count;
       }
 
 #if CONFIG_RECT_TX_EXT
       if (is_quarter_tx_allowed(xd, mbmi, is_inter) &&
-          quarter_txsize_lookup[bsize] != max_txsize_rect_lookup[bsize] &&
+          quarter_txsize_lookup[bsize] !=
+              get_max_rect_tx_size(bsize, is_inter) &&
           (mbmi->tx_size == quarter_txsize_lookup[bsize] ||
-           mbmi->tx_size == max_txsize_rect_lookup[bsize])) {
+           mbmi->tx_size == get_max_rect_tx_size(bsize, is_inter))) {
         const int use_qttx = mbmi->tx_size == quarter_txsize_lookup[bsize];
         ++td->counts->quarter_tx_size[use_qttx];
 #if CONFIG_NEW_MULTISYMBOL
@@ -4975,7 +4979,8 @@
             mi_8x8[mis * j + i]->mbmi.tx_size = intra_tx_size;
 
       mbmi->min_tx_size = get_min_tx_size(intra_tx_size);
-      if (intra_tx_size != max_txsize_rect_lookup[bsize]) ++x->txb_split_count;
+      if (intra_tx_size != get_max_rect_tx_size(bsize, is_inter))
+        ++x->txb_split_count;
     }
 
 #if !CONFIG_TXK_SEL
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 1c9f3f4..a95fa6f 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -670,30 +670,22 @@
     assert(IMPLIES(tx_size > TX_4X4, sub_txs < tx_size));
 #endif
     // This is the square transform block partition entry point.
-    int bsl = tx_size_wide_unit[sub_txs];
-    int i;
-    assert(bsl > 0);
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
+    const int step = bsh * bsw;
+    assert(bsw > 0 && bsh > 0);
 
-    for (i = 0; i < 4; ++i) {
-#if CONFIG_RECT_TX_EXT
-      int is_wide_tx = tx_size_wide_unit[sub_txs] > tx_size_high_unit[sub_txs];
-      const int offsetr =
-          is_qttx ? (is_wide_tx ? i * tx_size_high_unit[sub_txs] : 0)
-                  : blk_row + ((i >> 1) * bsl);
-      const int offsetc =
-          is_qttx ? (is_wide_tx ? 0 : i * tx_size_wide_unit[sub_txs])
-                  : blk_col + ((i & 0x01) * bsl);
-#else
-      const int offsetr = blk_row + ((i >> 1) * bsl);
-      const int offsetc = blk_col + ((i & 0x01) * bsl);
-#endif
-      int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        const int offsetr = blk_row + row;
+        const int offsetc = blk_col + col;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      encode_block_inter(plane, block, offsetr, offsetc, plane_bsize, sub_txs,
-                         arg);
-      block += step;
+        encode_block_inter(plane, block, offsetr, offsetc, plane_bsize, sub_txs,
+                           arg);
+        block += step;
+      }
     }
   }
 }
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index 76b417c..fa9a41b 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -145,7 +145,7 @@
 }
 #endif  // CONFIG_TX64X64
 
-#if CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT || (CONFIG_EXT_PARTITION_TYPES && USE_RECT_TX_EXT)
 static void fwd_txfm_16x4(const int16_t *src_diff, tran_low_t *coeff,
                           int diff_stride, TxfmParam *txfm_param) {
   av1_fht16x4(src_diff, coeff, diff_stride, txfm_param);
@@ -503,7 +503,7 @@
       fwd_txfm_32x16(src_diff, coeff, diff_stride, txfm_param);
       break;
     case TX_4X4: fwd_txfm_4x4(src_diff, coeff, diff_stride, txfm_param); break;
-#if CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT || (CONFIG_EXT_PARTITION_TYPES && USE_RECT_TX_EXT)
     case TX_4X16:
       fwd_txfm_4x16(src_diff, coeff, diff_stride, txfm_param);
       break;
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index c001277..f74fac6 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -897,7 +897,7 @@
       for (i = 0; i < num_4x4_h; i += 4)
         t_left[i] = !!*(const uint32_t *)&left[i];
       break;
-#if CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT || (CONFIG_EXT_PARTITION_TYPES && USE_RECT_TX_EXT)
     case TX_4X16:
       memcpy(t_above, above, sizeof(ENTROPY_CONTEXT) * num_4x4_w);
       for (i = 0; i < num_4x4_h; i += 4)
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index abfa631..4954c8f 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2663,7 +2663,7 @@
     TX_TYPE tx_type;
     for (tx_type = tx_start; tx_type < tx_end; ++tx_type) {
       if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) continue;
-      const TX_SIZE rect_tx_size = max_txsize_rect_lookup[bs];
+      const TX_SIZE rect_tx_size = get_max_rect_tx_size(bs, is_inter);
       RD_STATS this_rd_stats;
       const TxSetType tx_set_type = get_ext_tx_set_type(
           rect_tx_size, bs, is_inter, cm->reduced_tx_set_used);
@@ -3890,7 +3890,6 @@
   int64_t this_rd = INT64_MAX;
   ENTROPY_CONTEXT *pta = ta + blk_col;
   ENTROPY_CONTEXT *ptl = tl + blk_row;
-  int i;
   int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
                                    mbmi->sb_type, tx_size);
   int64_t sum_rd = INT64_MAX;
@@ -3904,7 +3903,7 @@
 #if CONFIG_RECT_TX_EXT
   TX_SIZE quarter_txsize = quarter_txsize_lookup[mbmi->sb_type];
   int check_qttx = is_quarter_tx_allowed(xd, mbmi, is_inter_block(mbmi)) &&
-                   tx_size == max_txsize_rect_lookup[mbmi->sb_type] &&
+                   tx_size == get_max_rect_tx_size(mbmi->sb_type, 1) &&
                    quarter_txsize != tx_size;
   int is_qttx_picked = 0;
   int eobs_qttx[2] = { 0, 0 };
@@ -4102,13 +4101,14 @@
 #endif  // CONFIG_MRC_TX
       ) {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int sub_step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
+    int sub_step = bsw * bsh;
     RD_STATS this_rd_stats;
     int this_cost_valid = 1;
     int64_t tmp_rd = 0;
 #if CONFIG_DIST_8X8
-    int sub8x8_eob[4];
+    int sub8x8_eob[4] = { 0, 0, 0, 0 };
 #endif
     sum_rd_stats.rate = x->txfm_partition_cost[ctx][1];
 
@@ -4116,29 +4116,36 @@
 
     ref_best_rd = AOMMIN(this_rd, ref_best_rd);
 
-    for (i = 0; i < 4 && this_cost_valid; ++i) {
-      int offsetr = blk_row + (i >> 1) * bsl;
-      int offsetc = blk_col + (i & 0x01) * bsl;
+    for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
+      for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw) {
+        int offsetr = blk_row + r;
+        int offsetc = blk_col + c;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      select_tx_block(cpi, x, offsetr, offsetc, plane, block, sub_txs,
-                      depth + 1, plane_bsize, ta, tl, tx_above, tx_left,
-                      &this_rd_stats, ref_best_rd - tmp_rd, &this_cost_valid,
-                      fast, 0);
+        select_tx_block(cpi, x, offsetr, offsetc, plane, block, sub_txs,
+                        depth + 1, plane_bsize, ta, tl, tx_above, tx_left,
+                        &this_rd_stats, ref_best_rd - tmp_rd, &this_cost_valid,
+                        fast, 0);
 #if CONFIG_DIST_8X8
-      if (x->using_dist_8x8 && plane == 0 && tx_size == TX_8X8) {
-        sub8x8_eob[i] = p->eobs[block];
-      }
-#endif  // CONFIG_DIST_8X8
-      av1_merge_rd_stats(&sum_rd_stats, &this_rd_stats);
-
-      tmp_rd = RDCOST(x->rdmult, sum_rd_stats.rate, sum_rd_stats.dist);
-#if CONFIG_DIST_8X8
-      if (!x->using_dist_8x8)
+        if (!x->using_dist_8x8)
 #endif
-        if (this_rd < tmp_rd) break;
-      block += sub_step;
+          if (!this_cost_valid) goto LOOP_EXIT;
+#if CONFIG_DIST_8X8
+        if (x->using_dist_8x8 && plane == 0 && tx_size == TX_8X8) {
+          sub8x8_eob[2 * (r / bsh) + (c / bsw)] = p->eobs[block];
+        }
+#endif  // CONFIG_DIST_8X8
+        av1_merge_rd_stats(&sum_rd_stats, &this_rd_stats);
+
+        tmp_rd = RDCOST(x->rdmult, sum_rd_stats.rate, sum_rd_stats.dist);
+#if CONFIG_DIST_8X8
+        if (!x->using_dist_8x8)
+#endif
+          if (this_rd < tmp_rd) goto LOOP_EXIT;
+        block += sub_step;
+      }
     }
+  LOOP_EXIT:;  // Early termination must leave both the row and column loops.
 #if CONFIG_DIST_8X8
     if (x->using_dist_8x8 && this_cost_valid && plane == 0 &&
@@ -4158,7 +4164,7 @@
       const int pred_idx = (blk_row * pred_stride + blk_col)
                            << tx_size_wide_log2[0];
       int16_t *pred = &pd->pred[pred_idx];
-      int j;
+      int i, j;
       int row, col;
 
 #if CONFIG_HIGHBITDEPTH
@@ -4325,7 +4331,7 @@
     const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, pd);
     const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
     const int mi_height = block_size_high[plane_bsize] >> tx_size_high_log2[0];
-    const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
+    const TX_SIZE max_tx_size = get_max_rect_tx_size(plane_bsize, 1);
     const int bh = tx_size_high_unit[max_tx_size];
     const int bw = tx_size_wide_unit[max_tx_size];
     int idx, idy;
@@ -4698,7 +4704,8 @@
   DECLARE_ALIGNED(32, tran_low_t, DCT_coefs[32 * 32]);
   TxfmParam param;
   param.tx_type = DCT_DCT;
-  param.tx_size = max_txsize_rect_lookup[bsize];
+  param.tx_size =
+      get_max_rect_tx_size(bsize, is_inter_block(&x->e_mbd.mi[0]->mbmi));
   param.bd = xd->bd;
   param.is_hbd = get_bitdepth_data_path_index(xd);
   param.lossless = 0;
@@ -4740,7 +4747,7 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const int n4 = bsize_to_num_blk(bsize);
-  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
+  const TX_SIZE tx_size = get_max_rect_tx_size(bsize, is_inter_block(mbmi));
   mbmi->tx_type = DCT_DCT;
   for (int idy = 0; idy < xd->n8_h; ++idy)
     for (int idx = 0; idx < xd->n8_w; ++idx)
@@ -5042,7 +5049,8 @@
       const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
       const int mi_height =
           block_size_high[plane_bsize] >> tx_size_high_log2[0];
-      const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
+      const TX_SIZE max_tx_size =
+          get_max_rect_tx_size(plane_bsize, is_inter_block(mbmi));
       const int bh = tx_size_high_unit[max_tx_size];
       const int bw = tx_size_wide_unit[max_tx_size];
       int idx, idy;
@@ -5052,7 +5060,6 @@
       ENTROPY_CONTEXT tl[2 * MAX_MIB_SIZE];
       RD_STATS pn_rd_stats;
       av1_init_rd_stats(&pn_rd_stats);
-
       av1_get_entropy_contexts(bsize, 0, pd, ta, tl);
 
       for (idy = 0; idy < mi_height; idy += bh) {
diff --git a/av1/encoder/tokenize.c b/av1/encoder/tokenize.c
index ef18298..4bc36c9 100644
--- a/av1/encoder/tokenize.c
+++ b/av1/encoder/tokenize.c
@@ -612,32 +612,23 @@
     // Half the block size in transform block unit.
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
 #endif
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
+    const int step = bsw * bsh;
 
-    assert(bsl > 0);
+    assert(bsw > 0 && bsh > 0);
 
-    for (i = 0; i < 4; ++i) {
-#if CONFIG_RECT_TX_EXT
-      int is_wide_tx = tx_size_wide_unit[sub_txs] > tx_size_high_unit[sub_txs];
-      const int offsetr =
-          is_qttx ? (is_wide_tx ? i * tx_size_high_unit[sub_txs] : 0)
-                  : blk_row + ((i >> 1) * bsl);
-      const int offsetc =
-          is_qttx ? (is_wide_tx ? 0 : i * tx_size_wide_unit[sub_txs])
-                  : blk_col + ((i & 0x01) * bsl);
-#else
-      const int offsetr = blk_row + ((i >> 1) * bsl);
-      const int offsetc = blk_col + ((i & 0x01) * bsl);
-#endif
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        const int offsetr = blk_row + row;
+        const int offsetc = blk_col + col;
 
-      int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
-
-      tokenize_vartx(td, t, dry_run, sub_txs, plane_bsize, offsetr, offsetc,
-                     block, plane, arg);
-      block += step;
+        tokenize_vartx(td, t, dry_run, sub_txs, plane_bsize, offsetr, offsetc,
+                       block, plane, arg);
+        block += step;
+      }
     }
   }
 }