Make cb4x4 mode support supertx

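This commit makes cb4x4 mode support the supertx operation. When
CONFIG_CB4X4 is enabled, the encoder's supertx paths use a unify_bsize
flag so that BLOCK_8X8 partitions are handled like larger blocks instead
of through the sub-8x8 leaf_split special cases. The 8x8-specific
assertions are compiled out, and partition context/counts are only
used for blocks of size BLOCK_8X8 or larger.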

Change-Id: I1a713b2268c1029aebeb43aa6aeb0fa37b16810f
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 4515f38..3298198 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -376,9 +376,11 @@
   x->mv_row_max = (cm->mi_rows - mi_row_pred) * MI_SIZE + AOM_INTERP_EXTEND;
   x->mv_col_max = (cm->mi_cols - mi_col_pred) * MI_SIZE + AOM_INTERP_EXTEND;
 
-  // Set up distance of MB to edge of frame in 1/8th pel units.
+// Set up distance of MB to edge of frame in 1/8th pel units.
+#if !CONFIG_CB4X4
   assert(!(mi_col_pred & (mi_width - mi_size_wide[BLOCK_8X8])) &&
          !(mi_row_pred & (mi_height - mi_size_high[BLOCK_8X8])));
+#endif
   set_mi_row_col(xd, tile, mi_row_pred, mi_height, mi_col_pred, mi_width,
                  cm->mi_rows, cm->mi_cols);
   xd->up_available = (mi_row_ori > tile->mi_row_start);
@@ -1403,6 +1405,11 @@
   struct macroblock_plane *const p = x->plane;
   struct macroblockd_plane *const pd = xd->plane;
   int hbs = mi_size_wide[bsize] / 2;
+#if CONFIG_CB4X4
+  const int unify_bsize = 1;
+#else
+  const int unify_bsize = 0;
+#endif
   PARTITION_TYPE partition = pc_tree->partitioning;
   BLOCK_SIZE subsize = get_subsize(bsize, partition);
   int i;
@@ -1426,7 +1433,7 @@
       set_offsets_supertx(cpi, td, tile, mi_row, mi_col, subsize);
       update_state_supertx(cpi, td, &pc_tree->vertical[0], mi_row, mi_col,
                            subsize, dry_run);
-      if (mi_col + hbs < cm->mi_cols && bsize > BLOCK_8X8) {
+      if (mi_col + hbs < cm->mi_cols && (bsize > BLOCK_8X8 || unify_bsize)) {
         set_offsets_supertx(cpi, td, tile, mi_row, mi_col + hbs, subsize);
         update_state_supertx(cpi, td, &pc_tree->vertical[1], mi_row,
                              mi_col + hbs, subsize, dry_run);
@@ -1437,7 +1444,7 @@
       set_offsets_supertx(cpi, td, tile, mi_row, mi_col, subsize);
       update_state_supertx(cpi, td, &pc_tree->horizontal[0], mi_row, mi_col,
                            subsize, dry_run);
-      if (mi_row + hbs < cm->mi_rows && bsize > BLOCK_8X8) {
+      if (mi_row + hbs < cm->mi_rows && (bsize > BLOCK_8X8 || unify_bsize)) {
         set_offsets_supertx(cpi, td, tile, mi_row + hbs, mi_col, subsize);
         update_state_supertx(cpi, td, &pc_tree->horizontal[1], mi_row + hbs,
                              mi_col, subsize, dry_run);
@@ -1445,7 +1452,7 @@
       pmc = &pc_tree->horizontal_supertx;
       break;
     case PARTITION_SPLIT:
-      if (bsize == BLOCK_8X8) {
+      if (bsize == BLOCK_8X8 && !unify_bsize) {
         set_offsets_supertx(cpi, td, tile, mi_row, mi_col, subsize);
         update_state_supertx(cpi, td, pc_tree->leaf_split[0], mi_row, mi_col,
                              subsize, dry_run);
@@ -1558,6 +1565,11 @@
   const int hbs = mi_size_wide[bsize] / 2;
   PARTITION_TYPE partition = pc_tree->partitioning;
   BLOCK_SIZE subsize = get_subsize(bsize, partition);
+#if CONFIG_CB4X4
+  const int unify_bsize = 1;
+#else
+  const int unify_bsize = 0;
+#endif
 #if CONFIG_EXT_PARTITION_TYPES
   int i;
 #endif
@@ -1570,17 +1582,17 @@
       break;
     case PARTITION_VERT:
       update_supertx_param(td, &pc_tree->vertical[0], best_tx, supertx_size);
-      if (mi_col + hbs < cm->mi_cols && bsize > BLOCK_8X8)
+      if (mi_col + hbs < cm->mi_cols && (bsize > BLOCK_8X8 || unify_bsize))
         update_supertx_param(td, &pc_tree->vertical[1], best_tx, supertx_size);
       break;
     case PARTITION_HORZ:
       update_supertx_param(td, &pc_tree->horizontal[0], best_tx, supertx_size);
-      if (mi_row + hbs < cm->mi_rows && bsize > BLOCK_8X8)
+      if (mi_row + hbs < cm->mi_rows && (bsize > BLOCK_8X8 || unify_bsize))
         update_supertx_param(td, &pc_tree->horizontal[1], best_tx,
                              supertx_size);
       break;
     case PARTITION_SPLIT:
-      if (bsize == BLOCK_8X8) {
+      if (bsize == BLOCK_8X8 && !unify_bsize) {
         update_supertx_param(td, pc_tree->leaf_split[0], best_tx, supertx_size);
       } else {
         update_supertx_param_sb(cpi, td, mi_row, mi_col, subsize, best_tx,
@@ -3969,8 +3981,9 @@
                      subsize, &pc_tree->horizontal[0], best_rdc.rdcost);
 
 #if CONFIG_SUPERTX
-    abort_flag = (sum_rdc.rdcost >= best_rd && bsize > BLOCK_8X8) ||
-                 (sum_rdc.rate == INT_MAX && bsize == BLOCK_8X8);
+    abort_flag =
+        (sum_rdc.rdcost >= best_rd && (bsize > BLOCK_8X8 || unify_bsize)) ||
+        (sum_rdc.rate == INT_MAX && bsize == BLOCK_8X8);
     if (sum_rdc.rdcost < INT64_MAX &&
 #else
     if (sum_rdc.rdcost < best_rdc.rdcost &&
@@ -4113,8 +4126,9 @@
 #endif
                      subsize, &pc_tree->vertical[0], best_rdc.rdcost);
 #if CONFIG_SUPERTX
-    abort_flag = (sum_rdc.rdcost >= best_rd && bsize > BLOCK_8X8) ||
-                 (sum_rdc.rate == INT_MAX && bsize == BLOCK_8X8);
+    abort_flag =
+        (sum_rdc.rdcost >= best_rd && (bsize > BLOCK_8X8 || unify_bsize)) ||
+        (sum_rdc.rate == INT_MAX && bsize == BLOCK_8X8);
     if (sum_rdc.rdcost < INT64_MAX &&
 #else
     if (sum_rdc.rdcost < best_rdc.rdcost &&
@@ -5576,15 +5590,21 @@
                           int mi_row, int mi_col, BLOCK_SIZE bsize,
                           PC_TREE *pc_tree) {
   const AV1_COMMON *const cm = &cpi->common;
-
   const int hbs = mi_size_wide[bsize] / 2;
   const PARTITION_TYPE partition = pc_tree->partitioning;
   const BLOCK_SIZE subsize = get_subsize(bsize, partition);
 #if CONFIG_EXT_PARTITION_TYPES
   int i;
 #endif
+#if CONFIG_CB4X4
+  const int unify_bsize = 1;
+#else
+  const int unify_bsize = 0;
+#endif
 
+#if !CONFIG_CB4X4
   assert(bsize >= BLOCK_8X8);
+#endif
 
   if (mi_row >= cm->mi_rows || mi_col >= cm->mi_cols) return 1;
 
@@ -5592,18 +5612,18 @@
     case PARTITION_NONE: return check_intra_b(&pc_tree->none); break;
     case PARTITION_VERT:
       if (check_intra_b(&pc_tree->vertical[0])) return 1;
-      if (mi_col + hbs < cm->mi_cols && bsize > BLOCK_8X8) {
+      if (mi_col + hbs < cm->mi_cols && (bsize > BLOCK_8X8 || unify_bsize)) {
         if (check_intra_b(&pc_tree->vertical[1])) return 1;
       }
       break;
     case PARTITION_HORZ:
       if (check_intra_b(&pc_tree->horizontal[0])) return 1;
-      if (mi_row + hbs < cm->mi_rows && bsize > BLOCK_8X8) {
+      if (mi_row + hbs < cm->mi_rows && (bsize > BLOCK_8X8 || unify_bsize)) {
         if (check_intra_b(&pc_tree->horizontal[1])) return 1;
       }
       break;
     case PARTITION_SPLIT:
-      if (bsize == BLOCK_8X8) {
+      if (bsize == BLOCK_8X8 && !unify_bsize) {
         if (check_intra_b(pc_tree->leaf_split[0])) return 1;
       } else {
         if (check_intra_sb(cpi, tile, mi_row, mi_col, subsize,
@@ -5655,6 +5675,11 @@
                             PC_TREE *pc_tree) {
   PARTITION_TYPE partition;
   BLOCK_SIZE subsize;
+#if CONFIG_CB4X4
+  const int unify_bsize = 1;
+#else
+  const int unify_bsize = 0;
+#endif
 
   partition = pc_tree->partitioning;
   subsize = get_subsize(bsize, partition);
@@ -5665,7 +5690,7 @@
     case PARTITION_HORZ:
       return check_supertx_b(supertx_size, &pc_tree->horizontal[0]);
     case PARTITION_SPLIT:
-      if (bsize == BLOCK_8X8)
+      if (bsize == BLOCK_8X8 && !unify_bsize)
         return check_supertx_b(supertx_size, pc_tree->leaf_split[0]);
       else
         return check_supertx_sb(subsize, supertx_size, pc_tree->split[0]);
@@ -5790,7 +5815,12 @@
   const int mi_height = mi_size_high[bsize];
   int xss = xd->plane[1].subsampling_x;
   int yss = xd->plane[1].subsampling_y;
-  int b_sub8x8 = (bsize < BLOCK_8X8) ? 1 : 0;
+#if CONFIG_CB4X4
+  const int unify_bsize = 1;
+#else
+  const int unify_bsize = 0;
+#endif
+  int b_sub8x8 = (bsize < BLOCK_8X8) && !unify_bsize ? 1 : 0;
 
   BLOCK_SIZE extend_bsize;
   int unit, mi_row_pred, mi_col_pred;
@@ -5895,8 +5925,10 @@
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCK *const x = &td->mb;
   MACROBLOCKD *const xd = &x->e_mbd;
-
-  const int ctx = partition_plane_context(xd, mi_row, mi_col, bsize);
+  const int is_partition_root = bsize >= BLOCK_8X8;
+  const int ctx = is_partition_root
+                      ? partition_plane_context(xd, mi_row, mi_col, bsize)
+                      : 0;
   const int hbs = mi_size_wide[bsize] / 2;
   const PARTITION_TYPE partition = pc_tree->partitioning;
   const BLOCK_SIZE subsize = get_subsize(bsize, partition);
@@ -5912,8 +5944,12 @@
   int dst_stride1[3] = { MAX_TX_SIZE, MAX_TX_SIZE, MAX_TX_SIZE };
   int dst_stride2[3] = { MAX_TX_SIZE, MAX_TX_SIZE, MAX_TX_SIZE };
   int dst_stride3[3] = { MAX_TX_SIZE, MAX_TX_SIZE, MAX_TX_SIZE };
-
+#if CONFIG_CB4X4
+  const int unify_bsize = 1;
+#else
+  const int unify_bsize = 0;
   assert(bsize >= BLOCK_8X8);
+#endif
 
   if (mi_row >= cm->mi_rows || mi_col >= cm->mi_cols) return;
 
@@ -5944,7 +5980,7 @@
   }
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
-  if (!dry_run && bsize < top_bsize) {
+  if (!dry_run && is_partition_root && bsize < top_bsize) {
     // Explicitly cast away const.
     FRAME_COUNTS *const frame_counts = (FRAME_COUNTS *)&cm->counts;
     frame_counts->partition[ctx][partition]++;
@@ -5965,7 +6001,7 @@
                  mi_col_top, dry_run, dst_buf, dst_stride);
       break;
     case PARTITION_HORZ:
-      if (bsize == BLOCK_8X8) {
+      if (bsize == BLOCK_8X8 && !unify_bsize) {
         // Fisrt half
         predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
                          mi_row_top, mi_col_top, dst_buf, dst_stride, top_bsize,
@@ -6028,7 +6064,7 @@
       }
       break;
     case PARTITION_VERT:
-      if (bsize == BLOCK_8X8) {
+      if (bsize == BLOCK_8X8 && !unify_bsize) {
         // First half
         predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
                          mi_row_top, mi_col_top, dst_buf, dst_stride, top_bsize,
@@ -6089,7 +6125,7 @@
       }
       break;
     case PARTITION_SPLIT:
-      if (bsize == BLOCK_8X8) {
+      if (bsize == BLOCK_8X8 && !unify_bsize) {
         predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
                          mi_row_top, mi_col_top, dst_buf, dst_stride, top_bsize,
                          BLOCK_8X8, dry_run, 1, 0);
@@ -6132,7 +6168,7 @@
                              pc_tree->split[3]);
       }
       for (i = 0; i < MAX_MB_PLANE; i++) {
-        if (bsize == BLOCK_8X8 && i != 0)
+        if (bsize == BLOCK_8X8 && i != 0 && !unify_bsize)
           continue;  // Skip <4x4 chroma smoothing
         if (mi_row < cm->mi_rows && mi_col + hbs < cm->mi_cols) {
           av1_build_masked_inter_predictor_complex(