supertx: code refactoring + resolve conflicts with baseline

Refactoring: split prediction+extension for each plane, so we can
handle luma/chroma supertx pred in different ways.
Compatibility fix: fix conflicts with cb4x4 and chroma_sub8x8, now
for chroma sub8x8 supertx, only the top-left(basic cb4x4) or the
the bottom-right(cb4x4 + chroma_sub8x8) predictor will be used
without any blending within a 8x8 unit.

Change-Id: I6cf7b12768a82d3c7e01811ada02de84af9bd8ac
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index f1e3a2e..ea265ec 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -758,6 +758,20 @@
 #endif
 }
 
+#if CONFIG_SUPERTX
+static INLINE int need_handle_chroma_sub8x8(BLOCK_SIZE bsize, int subsampling_x,
+                                            int subsampling_y) {
+  const int bw = mi_size_wide[bsize];
+  const int bh = mi_size_high[bsize];
+
+  if (bsize >= BLOCK_8X8 ||
+      ((!(bh & 0x01) || !subsampling_y) && (!(bw & 0x01) || !subsampling_x)))
+    return 0;
+  else
+    return 1;
+}
+#endif
+
 static INLINE BLOCK_SIZE scale_chroma_bsize(BLOCK_SIZE bsize, int subsampling_x,
                                             int subsampling_y) {
   BLOCK_SIZE bs = bsize;
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index d62d3ab..a0d1b63 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -1580,19 +1580,19 @@
   } while (--h_remain);
 }
 
-void av1_build_inter_predictors_sb_sub8x8_extend(const AV1_COMMON *cm,
-                                                 MACROBLOCKD *xd,
+void av1_build_inter_predictor_sb_sub8x8_extend(const AV1_COMMON *cm,
+                                                MACROBLOCKD *xd,
 #if CONFIG_EXT_INTER
-                                                 int mi_row_ori, int mi_col_ori,
+                                                int mi_row_ori, int mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                                                 int mi_row, int mi_col,
-                                                 BLOCK_SIZE bsize, int block) {
+                                                int mi_row, int mi_col,
+                                                int plane, BLOCK_SIZE bsize,
+                                                int block) {
   // Prediction function used in supertx:
   // Use the mv at current block (which is less than 8x8)
   // to get prediction of a block located at (mi_row, mi_col) at size of bsize
   // bsize can be larger than 8x8.
   // block (0-3): the sub8x8 location of current block
-  int plane;
   const int mi_x = mi_col * MI_SIZE;
   const int mi_y = mi_row * MI_SIZE;
 #if CONFIG_EXT_INTER
@@ -1603,68 +1603,50 @@
   // For sub8x8 uv:
   // Skip uv prediction in supertx except the first block (block = 0)
   int max_plane = block ? 1 : MAX_MB_PLANE;
+  if (plane >= max_plane) return;
 
-  for (plane = 0; plane < max_plane; plane++) {
-    const BLOCK_SIZE plane_bsize =
-        get_plane_block_size(bsize, &xd->plane[plane]);
-    const int num_4x4_w = num_4x4_blocks_wide_lookup[plane_bsize];
-    const int num_4x4_h = num_4x4_blocks_high_lookup[plane_bsize];
-    const int bw = 4 * num_4x4_w;
-    const int bh = 4 * num_4x4_h;
+  const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, &xd->plane[plane]);
+  const int num_4x4_w = num_4x4_blocks_wide_lookup[plane_bsize];
+  const int num_4x4_h = num_4x4_blocks_high_lookup[plane_bsize];
+  const int bw = 4 * num_4x4_w;
+  const int bh = 4 * num_4x4_h;
 
-    build_inter_predictors(cm, xd, plane,
+  build_inter_predictors(cm, xd, plane,
 #if CONFIG_MOTION_VAR
-                           0, 0,
+                         0, 0,
 #endif  // CONFIG_MOTION_VAR
-                           block, bw, bh, 0, 0, bw, bh,
+                         block, bw, bh, 0, 0, bw, bh,
 #if CONFIG_EXT_INTER
-                           wedge_offset_x, wedge_offset_y,
+                         wedge_offset_x, wedge_offset_y,
 #endif  // CONFIG_EXT_INTER
-                           mi_x, mi_y);
-  }
-#if CONFIG_EXT_INTER
-  if (is_interintra_pred(&xd->mi[0]->mbmi)) {
-    BUFFER_SET ctx = { { xd->plane[0].dst.buf, xd->plane[1].dst.buf,
-                         xd->plane[2].dst.buf },
-                       { xd->plane[0].dst.stride, xd->plane[1].dst.stride,
-                         xd->plane[2].dst.stride } };
-    av1_build_interintra_predictors(
-        xd, xd->plane[0].dst.buf, xd->plane[1].dst.buf, xd->plane[2].dst.buf,
-        xd->plane[0].dst.stride, xd->plane[1].dst.stride,
-        xd->plane[2].dst.stride, &ctx, bsize);
-  }
-#endif  // CONFIG_EXT_INTER
+                         mi_x, mi_y);
 }
 
-void av1_build_inter_predictors_sb_extend(const AV1_COMMON *cm, MACROBLOCKD *xd,
+void av1_build_inter_predictor_sb_extend(const AV1_COMMON *cm, MACROBLOCKD *xd,
 #if CONFIG_EXT_INTER
-                                          int mi_row_ori, int mi_col_ori,
+                                         int mi_row_ori, int mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                                          int mi_row, int mi_col,
-                                          BLOCK_SIZE bsize) {
-  int plane;
+                                         int mi_row, int mi_col, int plane,
+                                         BLOCK_SIZE bsize) {
   const int mi_x = mi_col * MI_SIZE;
   const int mi_y = mi_row * MI_SIZE;
 #if CONFIG_EXT_INTER
   const int wedge_offset_x = (mi_col_ori - mi_col) * MI_SIZE;
   const int wedge_offset_y = (mi_row_ori - mi_row) * MI_SIZE;
 #endif  // CONFIG_EXT_INTER
-  for (plane = 0; plane < MAX_MB_PLANE; ++plane) {
-    const BLOCK_SIZE plane_bsize =
-        get_plane_block_size(bsize, &xd->plane[plane]);
-    const int bw = block_size_wide[plane_bsize];
-    const int bh = block_size_high[plane_bsize];
+  const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, &xd->plane[plane]);
+  const int bw = block_size_wide[plane_bsize];
+  const int bh = block_size_high[plane_bsize];
 
-    build_inter_predictors(cm, xd, plane,
+  build_inter_predictors(cm, xd, plane,
 #if CONFIG_MOTION_VAR
-                           0, 0,
+                         0, 0,
 #endif  // CONFIG_MOTION_VAR
-                           0, bw, bh, 0, 0, bw, bh,
+                         0, bw, bh, 0, 0, bw, bh,
 #if CONFIG_EXT_INTER
-                           wedge_offset_x, wedge_offset_y,
+                         wedge_offset_x, wedge_offset_y,
 #endif  // CONFIG_EXT_INTER
-                           mi_x, mi_y);
-  }
+                         mi_x, mi_y);
 }
 #endif  // CONFIG_SUPERTX
 
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index dfb1e7d..9465821 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -547,20 +547,21 @@
                                    BLOCK_SIZE bsize);
 
 #if CONFIG_SUPERTX
-void av1_build_inter_predictors_sb_sub8x8_extend(const AV1_COMMON *cm,
-                                                 MACROBLOCKD *xd,
+void av1_build_inter_predictor_sb_sub8x8_extend(const AV1_COMMON *cm,
+                                                MACROBLOCKD *xd,
 #if CONFIG_EXT_INTER
-                                                 int mi_row_ori, int mi_col_ori,
+                                                int mi_row_ori, int mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                                                 int mi_row, int mi_col,
-                                                 BLOCK_SIZE bsize, int block);
+                                                int mi_row, int mi_col,
+                                                int plane, BLOCK_SIZE bsize,
+                                                int block);
 
-void av1_build_inter_predictors_sb_extend(const AV1_COMMON *cm, MACROBLOCKD *xd,
+void av1_build_inter_predictor_sb_extend(const AV1_COMMON *cm, MACROBLOCKD *xd,
 #if CONFIG_EXT_INTER
-                                          int mi_row_ori, int mi_col_ori,
+                                         int mi_row_ori, int mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                                          int mi_row, int mi_col,
-                                          BLOCK_SIZE bsize);
+                                         int mi_row, int mi_col, int plane,
+                                         BLOCK_SIZE bsize);
 struct macroblockd_plane;
 void av1_build_masked_inter_predictor_complex(
     MACROBLOCKD *xd, uint8_t *dst, int dst_stride, const uint8_t *pre,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 42947f6..d7749dd 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -972,7 +972,7 @@
 static void dec_predict_b_extend(
     AV1Decoder *const pbi, MACROBLOCKD *const xd, const TileInfo *const tile,
     int block, int mi_row_ori, int mi_col_ori, int mi_row_pred, int mi_col_pred,
-    int mi_row_top, int mi_col_top, uint8_t *dst_buf[3], int dst_stride[3],
+    int mi_row_top, int mi_col_top, int plane, uint8_t *dst_buf, int dst_stride,
     BLOCK_SIZE bsize_top, BLOCK_SIZE bsize_pred, int b_sub8x8, int bextend) {
   // Used in supertx
   // (mi_row_ori, mi_col_ori): location for mv
@@ -1002,39 +1002,34 @@
 
   if (!bextend) mbmi->tx_size = max_txsize_lookup[bsize_top];
 
-  xd->plane[0].dst.stride = dst_stride[0];
-  xd->plane[1].dst.stride = dst_stride[1];
-  xd->plane[2].dst.stride = dst_stride[2];
-  xd->plane[0].dst.buf = dst_buf[0] +
-                         (r >> xd->plane[0].subsampling_y) * dst_stride[0] +
-                         (c >> xd->plane[0].subsampling_x);
-  xd->plane[1].dst.buf = dst_buf[1] +
-                         (r >> xd->plane[1].subsampling_y) * dst_stride[1] +
-                         (c >> xd->plane[1].subsampling_x);
-  xd->plane[2].dst.buf = dst_buf[2] +
-                         (r >> xd->plane[2].subsampling_y) * dst_stride[2] +
-                         (c >> xd->plane[2].subsampling_x);
+  xd->plane[plane].dst.stride = dst_stride;
+  xd->plane[plane].dst.buf =
+      dst_buf + (r >> xd->plane[plane].subsampling_y) * dst_stride +
+      (c >> xd->plane[plane].subsampling_x);
 
   if (!b_sub8x8)
-    av1_build_inter_predictors_sb_extend(&pbi->common, xd,
+    av1_build_inter_predictor_sb_extend(&pbi->common, xd,
 #if CONFIG_EXT_INTER
-                                         mi_row_ori, mi_col_ori,
+                                        mi_row_ori, mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                                         mi_row_pred, mi_col_pred, bsize_pred);
+                                        mi_row_pred, mi_col_pred, plane,
+                                        bsize_pred);
   else
-    av1_build_inter_predictors_sb_sub8x8_extend(&pbi->common, xd,
+    av1_build_inter_predictor_sb_sub8x8_extend(&pbi->common, xd,
 #if CONFIG_EXT_INTER
-                                                mi_row_ori, mi_col_ori,
+                                               mi_row_ori, mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                                                mi_row_pred, mi_col_pred,
-                                                bsize_pred, block);
+                                               mi_row_pred, mi_col_pred, plane,
+                                               bsize_pred, block);
 }
 
 static void dec_extend_dir(AV1Decoder *const pbi, MACROBLOCKD *const xd,
                            const TileInfo *const tile, int block,
-                           BLOCK_SIZE bsize, BLOCK_SIZE top_bsize, int mi_row,
+                           BLOCK_SIZE bsize, BLOCK_SIZE top_bsize,
+                           int mi_row_ori, int mi_col_ori, int mi_row,
                            int mi_col, int mi_row_top, int mi_col_top,
-                           uint8_t *dst_buf[3], int dst_stride[3], int dir) {
+                           int plane, uint8_t *dst_buf, int dst_stride,
+                           int dir) {
   // dir: 0-lower, 1-upper, 2-left, 3-right
   //      4-lowerleft, 5-upperleft, 6-lowerright, 7-upperright
   const int mi_width = mi_size_wide[bsize];
@@ -1074,9 +1069,9 @@
 
     for (j = 0; j < mi_height + ext_offset; j += high_unit)
       for (i = 0; i < mi_width + ext_offset; i += wide_unit)
-        dec_predict_b_extend(pbi, xd, tile, block, mi_row, mi_col,
+        dec_predict_b_extend(pbi, xd, tile, block, mi_row_ori, mi_col_ori,
                              mi_row_pred + j, mi_col_pred + i, mi_row_top,
-                             mi_col_top, dst_buf, dst_stride, top_bsize,
+                             mi_col_top, plane, dst_buf, dst_stride, top_bsize,
                              extend_bsize, b_sub8x8, 1);
   } else if (dir == 2 || dir == 3) {
     extend_bsize =
@@ -1098,9 +1093,9 @@
 
     for (j = 0; j < mi_height + ext_offset; j += high_unit)
       for (i = 0; i < mi_width + ext_offset; i += wide_unit)
-        dec_predict_b_extend(pbi, xd, tile, block, mi_row, mi_col,
+        dec_predict_b_extend(pbi, xd, tile, block, mi_row_ori, mi_col_ori,
                              mi_row_pred + j, mi_col_pred + i, mi_row_top,
-                             mi_col_top, dst_buf, dst_stride, top_bsize,
+                             mi_col_top, plane, dst_buf, dst_stride, top_bsize,
                              extend_bsize, b_sub8x8, 1);
   } else {
     extend_bsize = BLOCK_8X8;
@@ -1120,21 +1115,23 @@
 
     for (j = 0; j < mi_height + ext_offset; j += high_unit)
       for (i = 0; i < mi_width + ext_offset; i += wide_unit)
-        dec_predict_b_extend(pbi, xd, tile, block, mi_row, mi_col,
+        dec_predict_b_extend(pbi, xd, tile, block, mi_row_ori, mi_col_ori,
                              mi_row_pred + j, mi_col_pred + i, mi_row_top,
-                             mi_col_top, dst_buf, dst_stride, top_bsize,
+                             mi_col_top, plane, dst_buf, dst_stride, top_bsize,
                              extend_bsize, b_sub8x8, 1);
   }
 }
 
 static void dec_extend_all(AV1Decoder *const pbi, MACROBLOCKD *const xd,
                            const TileInfo *const tile, int block,
-                           BLOCK_SIZE bsize, BLOCK_SIZE top_bsize, int mi_row,
+                           BLOCK_SIZE bsize, BLOCK_SIZE top_bsize,
+                           int mi_row_ori, int mi_col_ori, int mi_row,
                            int mi_col, int mi_row_top, int mi_col_top,
-                           uint8_t *dst_buf[3], int dst_stride[3]) {
+                           int plane, uint8_t *dst_buf, int dst_stride) {
   for (int i = 0; i < 8; ++i) {
-    dec_extend_dir(pbi, xd, tile, block, bsize, top_bsize, mi_row, mi_col,
-                   mi_row_top, mi_col_top, dst_buf, dst_stride, i);
+    dec_extend_dir(pbi, xd, tile, block, bsize, top_bsize, mi_row_ori,
+                   mi_col_ori, mi_row, mi_col, mi_row_top, mi_col_top, plane,
+                   dst_buf, dst_stride, i);
   }
 }
 
@@ -1206,30 +1203,37 @@
   switch (partition) {
     case PARTITION_NONE:
       assert(bsize < top_bsize);
-      dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                           mi_row_top, mi_col_top, dst_buf, dst_stride,
-                           top_bsize, bsize, 0, 0);
-      dec_extend_all(pbi, xd, tile, 0, bsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dst_buf, dst_stride);
+      for (i = 0; i < MAX_MB_PLANE; i++) {
+        dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                             mi_row_top, mi_col_top, i, dst_buf[i],
+                             dst_stride[i], top_bsize, bsize, 0, 0);
+        dec_extend_all(pbi, xd, tile, 0, bsize, top_bsize, mi_row, mi_col,
+                       mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                       dst_stride[i]);
+      }
       break;
     case PARTITION_HORZ:
       if (bsize == BLOCK_8X8 && !unify_bsize) {
-        // For sub8x8, predict in 8x8 unit
-        // First half
-        dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                             mi_row_top, mi_col_top, dst_buf, dst_stride,
-                             top_bsize, BLOCK_8X8, 1, 0);
-        if (bsize < top_bsize)
-          dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride);
+        for (i = 0; i < MAX_MB_PLANE; i++) {
+          // For sub8x8, predict in 8x8 unit
+          // First half
+          dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf[i],
+                               dst_stride[i], top_bsize, BLOCK_8X8, 1, 0);
+          if (bsize < top_bsize)
+            dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                           mi_row, mi_col, mi_row_top, mi_col_top, i,
+                           dst_buf[i], dst_stride[i]);
 
-        // Second half
-        dec_predict_b_extend(pbi, xd, tile, 2, mi_row, mi_col, mi_row, mi_col,
-                             mi_row_top, mi_col_top, dst_buf1, dst_stride1,
-                             top_bsize, BLOCK_8X8, 1, 1);
-        if (bsize < top_bsize)
-          dec_extend_all(pbi, xd, tile, 2, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf1, dst_stride1);
+          // Second half
+          dec_predict_b_extend(pbi, xd, tile, 2, mi_row, mi_col, mi_row, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf1[i],
+                               dst_stride1[i], top_bsize, BLOCK_8X8, 1, 1);
+          if (bsize < top_bsize)
+            dec_extend_all(pbi, xd, tile, 2, subsize, top_bsize, mi_row, mi_col,
+                           mi_row, mi_col, mi_row_top, mi_col_top, i,
+                           dst_buf1[i], dst_stride1[i]);
+        }
 
         // weighted average to smooth the boundary
         xd->plane[0].dst.buf = dst_buf[0];
@@ -1239,60 +1243,91 @@
             mi_col, mi_row_top, mi_col_top, bsize, top_bsize, PARTITION_HORZ,
             0);
       } else {
-        // First half
-        dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                             mi_row_top, mi_col_top, dst_buf, dst_stride,
-                             top_bsize, subsize, 0, 0);
-        if (bsize < top_bsize)
-          dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride);
-        else
-          dec_extend_dir(pbi, xd, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride, 0);
+        for (i = 0; i < MAX_MB_PLANE; i++) {
+#if CONFIG_CB4X4
+          const struct macroblockd_plane *pd = &xd->plane[i];
+          int handle_chroma_sub8x8 = need_handle_chroma_sub8x8(
+              subsize, pd->subsampling_x, pd->subsampling_y);
 
-        if (mi_row + hbs < cm->mi_rows) {
-          // Second half
-          dec_predict_b_extend(pbi, xd, tile, 0, mi_row + hbs, mi_col,
-                               mi_row + hbs, mi_col, mi_row_top, mi_col_top,
-                               dst_buf1, dst_stride1, top_bsize, subsize, 0, 0);
-          if (bsize < top_bsize)
-            dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row + hbs,
-                           mi_col, mi_row_top, mi_col_top, dst_buf1,
-                           dst_stride1);
-          else
-            dec_extend_dir(pbi, xd, tile, 0, subsize, top_bsize, mi_row + hbs,
-                           mi_col, mi_row_top, mi_col_top, dst_buf1,
-                           dst_stride1, 1);
+          if (handle_chroma_sub8x8) {
+            int mode_offset_row = CONFIG_CHROMA_SUB8X8 ? hbs : 0;
 
-          // weighted average to smooth the boundary
-          for (i = 0; i < MAX_MB_PLANE; i++) {
-            xd->plane[i].dst.buf = dst_buf[i];
-            xd->plane[i].dst.stride = dst_stride[i];
-            av1_build_masked_inter_predictor_complex(
-                xd, dst_buf[i], dst_stride[i], dst_buf1[i], dst_stride1[i],
-                mi_row, mi_col, mi_row_top, mi_col_top, bsize, top_bsize,
-                PARTITION_HORZ, i);
+            dec_predict_b_extend(pbi, xd, tile, 0, mi_row + mode_offset_row,
+                                 mi_col, mi_row, mi_col, mi_row_top, mi_col_top,
+                                 i, dst_buf[i], dst_stride[i], top_bsize, bsize,
+                                 0, 0);
+            if (bsize < top_bsize)
+              dec_extend_all(pbi, xd, tile, 0, bsize, top_bsize,
+                             mi_row + mode_offset_row, mi_col, mi_row, mi_col,
+                             mi_row_top, mi_col_top, i, dst_buf[i],
+                             dst_stride[i]);
+          } else {
+#endif
+            // First half
+            dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row,
+                                 mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                                 dst_stride[i], top_bsize, subsize, 0, 0);
+            if (bsize < top_bsize)
+              dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row,
+                             mi_col, mi_row, mi_col, mi_row_top, mi_col_top, i,
+                             dst_buf[i], dst_stride[i]);
+            else
+              dec_extend_dir(pbi, xd, tile, 0, subsize, top_bsize, mi_row,
+                             mi_col, mi_row, mi_col, mi_row_top, mi_col_top, i,
+                             dst_buf[i], dst_stride[i], 0);
+
+            if (mi_row + hbs < cm->mi_rows) {
+              // Second half
+              dec_predict_b_extend(pbi, xd, tile, 0, mi_row + hbs, mi_col,
+                                   mi_row + hbs, mi_col, mi_row_top, mi_col_top,
+                                   i, dst_buf1[i], dst_stride1[i], top_bsize,
+                                   subsize, 0, 0);
+              if (bsize < top_bsize)
+                dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize,
+                               mi_row + hbs, mi_col, mi_row + hbs, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf1[i],
+                               dst_stride1[i]);
+              else
+                dec_extend_dir(pbi, xd, tile, 0, subsize, top_bsize,
+                               mi_row + hbs, mi_col, mi_row + hbs, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf1[i],
+                               dst_stride1[i], 1);
+
+              // weighted average to smooth the boundary
+              xd->plane[i].dst.buf = dst_buf[i];
+              xd->plane[i].dst.stride = dst_stride[i];
+              av1_build_masked_inter_predictor_complex(
+                  xd, dst_buf[i], dst_stride[i], dst_buf1[i], dst_stride1[i],
+                  mi_row, mi_col, mi_row_top, mi_col_top, bsize, top_bsize,
+                  PARTITION_HORZ, i);
+            }
+#if CONFIG_CB4X4
           }
+#endif
         }
       }
       break;
     case PARTITION_VERT:
       if (bsize == BLOCK_8X8 && !unify_bsize) {
-        // First half
-        dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                             mi_row_top, mi_col_top, dst_buf, dst_stride,
-                             top_bsize, BLOCK_8X8, 1, 0);
-        if (bsize < top_bsize)
-          dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride);
+        for (i = 0; i < MAX_MB_PLANE; i++) {
+          // First half
+          dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf[i],
+                               dst_stride[i], top_bsize, BLOCK_8X8, 1, 0);
+          if (bsize < top_bsize)
+            dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                           mi_row, mi_col, mi_row_top, mi_col_top, i,
+                           dst_buf[i], dst_stride[i]);
 
-        // Second half
-        dec_predict_b_extend(pbi, xd, tile, 1, mi_row, mi_col, mi_row, mi_col,
-                             mi_row_top, mi_col_top, dst_buf1, dst_stride1,
-                             top_bsize, BLOCK_8X8, 1, 1);
-        if (bsize < top_bsize)
-          dec_extend_all(pbi, xd, tile, 1, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf1, dst_stride1);
+          // Second half
+          dec_predict_b_extend(pbi, xd, tile, 1, mi_row, mi_col, mi_row, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf1[i],
+                               dst_stride1[i], top_bsize, BLOCK_8X8, 1, 1);
+          if (bsize < top_bsize)
+            dec_extend_all(pbi, xd, tile, 1, subsize, top_bsize, mi_row, mi_col,
+                           mi_row, mi_col, mi_row_top, mi_col_top, i,
+                           dst_buf1[i], dst_stride1[i]);
+        }
 
         // Smooth
         xd->plane[0].dst.buf = dst_buf[0];
@@ -1302,67 +1337,163 @@
             mi_col, mi_row_top, mi_col_top, bsize, top_bsize, PARTITION_VERT,
             0);
       } else {
-        // First half
-        dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                             mi_row_top, mi_col_top, dst_buf, dst_stride,
-                             top_bsize, subsize, 0, 0);
-        if (bsize < top_bsize)
-          dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride);
-        else
-          dec_extend_dir(pbi, xd, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride, 3);
+        for (i = 0; i < MAX_MB_PLANE; i++) {
+#if CONFIG_CB4X4
+          const struct macroblockd_plane *pd = &xd->plane[i];
+          int handle_chroma_sub8x8 = need_handle_chroma_sub8x8(
+              subsize, pd->subsampling_x, pd->subsampling_y);
 
-        // Second half
-        if (mi_col + hbs < cm->mi_cols) {
-          dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col + hbs, mi_row,
-                               mi_col + hbs, mi_row_top, mi_col_top, dst_buf1,
-                               dst_stride1, top_bsize, subsize, 0, 0);
-          if (bsize < top_bsize)
-            dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row,
-                           mi_col + hbs, mi_row_top, mi_col_top, dst_buf1,
-                           dst_stride1);
-          else
-            dec_extend_dir(pbi, xd, tile, 0, subsize, top_bsize, mi_row,
-                           mi_col + hbs, mi_row_top, mi_col_top, dst_buf1,
-                           dst_stride1, 2);
+          if (handle_chroma_sub8x8) {
+            int mode_offset_col = CONFIG_CHROMA_SUB8X8 ? hbs : 0;
+            assert(i > 0 && bsize == BLOCK_8X8);
 
-          // Smooth
-          for (i = 0; i < MAX_MB_PLANE; i++) {
-            xd->plane[i].dst.buf = dst_buf[i];
-            xd->plane[i].dst.stride = dst_stride[i];
-            av1_build_masked_inter_predictor_complex(
-                xd, dst_buf[i], dst_stride[i], dst_buf1[i], dst_stride1[i],
-                mi_row, mi_col, mi_row_top, mi_col_top, bsize, top_bsize,
-                PARTITION_VERT, i);
+            dec_predict_b_extend(pbi, xd, tile, 0, mi_row,
+                                 mi_col + mode_offset_col, mi_row, mi_col,
+                                 mi_row_top, mi_col_top, i, dst_buf[i],
+                                 dst_stride[i], top_bsize, bsize, 0, 0);
+            if (bsize < top_bsize)
+              dec_extend_all(pbi, xd, tile, 0, bsize, top_bsize, mi_row,
+                             mi_col + mode_offset_col, mi_row, mi_col,
+                             mi_row_top, mi_col_top, i, dst_buf[i],
+                             dst_stride[i]);
+          } else {
+#endif
+            // First half
+            dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row,
+                                 mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                                 dst_stride[i], top_bsize, subsize, 0, 0);
+            if (bsize < top_bsize)
+              dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row,
+                             mi_col, mi_row, mi_col, mi_row_top, mi_col_top, i,
+                             dst_buf[i], dst_stride[i]);
+            else
+              dec_extend_dir(pbi, xd, tile, 0, subsize, top_bsize, mi_row,
+                             mi_col, mi_row, mi_col, mi_row_top, mi_col_top, i,
+                             dst_buf[i], dst_stride[i], 3);
+
+            // Second half
+            if (mi_col + hbs < cm->mi_cols) {
+              dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col + hbs,
+                                   mi_row, mi_col + hbs, mi_row_top, mi_col_top,
+                                   i, dst_buf1[i], dst_stride1[i], top_bsize,
+                                   subsize, 0, 0);
+              if (bsize < top_bsize)
+                dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row,
+                               mi_col + hbs, mi_row, mi_col + hbs, mi_row_top,
+                               mi_col_top, i, dst_buf1[i], dst_stride1[i]);
+              else
+                dec_extend_dir(pbi, xd, tile, 0, subsize, top_bsize, mi_row,
+                               mi_col + hbs, mi_row, mi_col + hbs, mi_row_top,
+                               mi_col_top, i, dst_buf1[i], dst_stride1[i], 2);
+
+              // Smooth
+              xd->plane[i].dst.buf = dst_buf[i];
+              xd->plane[i].dst.stride = dst_stride[i];
+              av1_build_masked_inter_predictor_complex(
+                  xd, dst_buf[i], dst_stride[i], dst_buf1[i], dst_stride1[i],
+                  mi_row, mi_col, mi_row_top, mi_col_top, bsize, top_bsize,
+                  PARTITION_VERT, i);
+            }
+#if CONFIG_CB4X4
           }
+#endif
         }
       }
       break;
     case PARTITION_SPLIT:
       if (bsize == BLOCK_8X8 && !unify_bsize) {
-        dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                             mi_row_top, mi_col_top, dst_buf, dst_stride,
-                             top_bsize, BLOCK_8X8, 1, 0);
-        dec_predict_b_extend(pbi, xd, tile, 1, mi_row, mi_col, mi_row, mi_col,
-                             mi_row_top, mi_col_top, dst_buf1, dst_stride1,
-                             top_bsize, BLOCK_8X8, 1, 1);
-        dec_predict_b_extend(pbi, xd, tile, 2, mi_row, mi_col, mi_row, mi_col,
-                             mi_row_top, mi_col_top, dst_buf2, dst_stride2,
-                             top_bsize, BLOCK_8X8, 1, 1);
-        dec_predict_b_extend(pbi, xd, tile, 3, mi_row, mi_col, mi_row, mi_col,
-                             mi_row_top, mi_col_top, dst_buf3, dst_stride3,
-                             top_bsize, BLOCK_8X8, 1, 1);
-        if (bsize < top_bsize) {
-          dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride);
-          dec_extend_all(pbi, xd, tile, 1, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf1, dst_stride1);
-          dec_extend_all(pbi, xd, tile, 2, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf2, dst_stride2);
-          dec_extend_all(pbi, xd, tile, 3, subsize, top_bsize, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf3, dst_stride3);
+        for (i = 0; i < MAX_MB_PLANE; i++) {
+          dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf[i],
+                               dst_stride[i], top_bsize, BLOCK_8X8, 1, 0);
+          dec_predict_b_extend(pbi, xd, tile, 1, mi_row, mi_col, mi_row, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf1[i],
+                               dst_stride1[i], top_bsize, BLOCK_8X8, 1, 1);
+          dec_predict_b_extend(pbi, xd, tile, 2, mi_row, mi_col, mi_row, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf2[i],
+                               dst_stride2[i], top_bsize, BLOCK_8X8, 1, 1);
+          dec_predict_b_extend(pbi, xd, tile, 3, mi_row, mi_col, mi_row, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf3[i],
+                               dst_stride3[i], top_bsize, BLOCK_8X8, 1, 1);
+          if (bsize < top_bsize) {
+            dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                           mi_row, mi_col, mi_row_top, mi_col_top, i,
+                           dst_buf[i], dst_stride[i]);
+            dec_extend_all(pbi, xd, tile, 1, subsize, top_bsize, mi_row, mi_col,
+                           mi_row, mi_col, mi_row_top, mi_col_top, i,
+                           dst_buf1[i], dst_stride1[i]);
+            dec_extend_all(pbi, xd, tile, 2, subsize, top_bsize, mi_row, mi_col,
+                           mi_row, mi_col, mi_row_top, mi_col_top, i,
+                           dst_buf2[i], dst_stride2[i]);
+            dec_extend_all(pbi, xd, tile, 3, subsize, top_bsize, mi_row, mi_col,
+                           mi_row, mi_col, mi_row_top, mi_col_top, i,
+                           dst_buf3[i], dst_stride3[i]);
+          }
         }
+#if CONFIG_CB4X4
+      } else if (bsize == BLOCK_8X8) {
+        for (i = 0; i < MAX_MB_PLANE; i++) {
+          const struct macroblockd_plane *pd = &xd->plane[i];
+          int handle_chroma_sub8x8 = need_handle_chroma_sub8x8(
+              subsize, pd->subsampling_x, pd->subsampling_y);
+
+          if (handle_chroma_sub8x8) {
+            int mode_offset_row =
+                CONFIG_CHROMA_SUB8X8 && mi_row + hbs < cm->mi_rows ? hbs : 0;
+            int mode_offset_col =
+                CONFIG_CHROMA_SUB8X8 && mi_col + hbs < cm->mi_cols ? hbs : 0;
+
+            dec_predict_b_extend(pbi, xd, tile, 0, mi_row + mode_offset_row,
+                                 mi_col + mode_offset_col, mi_row, mi_col,
+                                 mi_row_top, mi_col_top, i, dst_buf[i],
+                                 dst_stride[i], top_bsize, BLOCK_8X8, 0, 0);
+            if (bsize < top_bsize)
+              dec_extend_all(pbi, xd, tile, 0, BLOCK_8X8, top_bsize,
+                             mi_row + mode_offset_row, mi_col + mode_offset_col,
+                             mi_row, mi_col, mi_row_top, mi_col_top, i,
+                             dst_buf[i], dst_stride[i]);
+          } else {
+            dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col, mi_row,
+                                 mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                                 dst_stride[i], top_bsize, subsize, 0, 0);
+            if (mi_row < cm->mi_rows && mi_col + hbs < cm->mi_cols)
+              dec_predict_b_extend(pbi, xd, tile, 0, mi_row, mi_col + hbs,
+                                   mi_row, mi_col + hbs, mi_row_top, mi_col_top,
+                                   i, dst_buf1[i], dst_stride1[i], top_bsize,
+                                   subsize, 0, 0);
+            if (mi_row + hbs < cm->mi_rows && mi_col < cm->mi_cols)
+              dec_predict_b_extend(pbi, xd, tile, 0, mi_row + hbs, mi_col,
+                                   mi_row + hbs, mi_col, mi_row_top, mi_col_top,
+                                   i, dst_buf2[i], dst_stride2[i], top_bsize,
+                                   subsize, 0, 0);
+            if (mi_row + hbs < cm->mi_rows && mi_col + hbs < cm->mi_cols)
+              dec_predict_b_extend(pbi, xd, tile, 0, mi_row + hbs, mi_col + hbs,
+                                   mi_row + hbs, mi_col + hbs, mi_row_top,
+                                   mi_col_top, i, dst_buf3[i], dst_stride3[i],
+                                   top_bsize, subsize, 0, 0);
+
+            if (bsize < top_bsize) {
+              dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row,
+                             mi_col, mi_row, mi_col, mi_row_top, mi_col_top, i,
+                             dst_buf[i], dst_stride[i]);
+              if (mi_row < cm->mi_rows && mi_col + hbs < cm->mi_cols)
+                dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize, mi_row,
+                               mi_col + hbs, mi_row, mi_col + hbs, mi_row_top,
+                               mi_col_top, i, dst_buf1[i], dst_stride1[i]);
+              if (mi_row + hbs < cm->mi_rows && mi_col < cm->mi_cols)
+                dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize,
+                               mi_row + hbs, mi_col, mi_row + hbs, mi_col,
+                               mi_row_top, mi_col_top, i, dst_buf2[i],
+                               dst_stride2[i]);
+              if (mi_row + hbs < cm->mi_rows && mi_col + hbs < cm->mi_cols)
+                dec_extend_all(pbi, xd, tile, 0, subsize, top_bsize,
+                               mi_row + hbs, mi_col + hbs, mi_row + hbs,
+                               mi_col + hbs, mi_row_top, mi_col_top, i,
+                               dst_buf3[i], dst_stride3[i]);
+            }
+          }
+        }
+#endif
       } else {
         dec_predict_sb_complex(pbi, xd, tile, mi_row, mi_col, mi_row_top,
                                mi_col_top, subsize, top_bsize, dst_buf,
@@ -1381,7 +1512,12 @@
                                  dst_buf3, dst_stride3);
       }
       for (i = 0; i < MAX_MB_PLANE; i++) {
-#if !CONFIG_CB4X4
+#if CONFIG_CB4X4
+        const struct macroblockd_plane *pd = &xd->plane[i];
+        int handle_chroma_sub8x8 = need_handle_chroma_sub8x8(
+            subsize, pd->subsampling_x, pd->subsampling_y);
+        if (handle_chroma_sub8x8) continue;  // Skip <4x4 chroma smoothing
+#else
         if (bsize == BLOCK_8X8 && i != 0)
           continue;  // Skip <4x4 chroma smoothing
 #endif
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index c8b8ec0..4e30d16 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -84,7 +84,7 @@
 #if CONFIG_EXT_INTER
                                int mi_row_ori, int mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                               int mi_row_pred, int mi_col_pred,
+                               int mi_row_pred, int mi_col_pred, int plane,
                                BLOCK_SIZE bsize_pred, int b_sub8x8, int block);
 static int check_supertx_sb(BLOCK_SIZE bsize, TX_SIZE supertx_size,
                             PC_TREE *pc_tree);
@@ -3161,7 +3161,7 @@
 #if CONFIG_SUPERTX
   int this_rate_nocoef, sum_rate_nocoef = 0, best_rate_nocoef = INT_MAX;
   int abort_flag;
-  const int supertx_allowed = !frame_is_intra_only(cm) &&
+  const int supertx_allowed = !frame_is_intra_only(cm) && bsize >= BLOCK_8X8 &&
                               bsize <= MAX_SUPERTX_BLOCK_SIZE &&
                               !xd->lossless[0];
 #endif  // CONFIG_SUPERTX
@@ -5837,7 +5837,7 @@
 #if CONFIG_EXT_INTER
                                int mi_row_ori, int mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                               int mi_row_pred, int mi_col_pred,
+                               int mi_row_pred, int mi_col_pred, int plane,
                                BLOCK_SIZE bsize_pred, int b_sub8x8, int block) {
   // Used in supertx
   // (mi_row_ori, mi_col_ori): location for mv
@@ -5860,27 +5860,28 @@
   }
 
   if (!b_sub8x8)
-    av1_build_inter_predictors_sb_extend(cm, xd,
+    av1_build_inter_predictor_sb_extend(cm, xd,
 #if CONFIG_EXT_INTER
-                                         mi_row_ori, mi_col_ori,
+                                        mi_row_ori, mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                                         mi_row_pred, mi_col_pred, bsize_pred);
+                                        mi_row_pred, mi_col_pred, plane,
+                                        bsize_pred);
   else
-    av1_build_inter_predictors_sb_sub8x8_extend(cm, xd,
+    av1_build_inter_predictor_sb_sub8x8_extend(cm, xd,
 #if CONFIG_EXT_INTER
-                                                mi_row_ori, mi_col_ori,
+                                               mi_row_ori, mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                                                mi_row_pred, mi_col_pred,
-                                                bsize_pred, block);
+                                               mi_row_pred, mi_col_pred, plane,
+                                               bsize_pred, block);
 }
 
 static void predict_b_extend(const AV1_COMP *const cpi, ThreadData *td,
                              const TileInfo *const tile, int block,
                              int mi_row_ori, int mi_col_ori, int mi_row_pred,
                              int mi_col_pred, int mi_row_top, int mi_col_top,
-                             uint8_t *dst_buf[3], int dst_stride[3],
+                             int plane, uint8_t *dst_buf, int dst_stride,
                              BLOCK_SIZE bsize_top, BLOCK_SIZE bsize_pred,
-                             RUN_TYPE dry_run, int b_sub8x8, int bextend) {
+                             RUN_TYPE dry_run, int b_sub8x8) {
   // Used in supertx
   // (mi_row_ori, mi_col_ori): location for mv
   // (mi_row_pred, mi_col_pred, bsize_pred): region to predict
@@ -5905,34 +5906,27 @@
 
   set_offsets_extend(cpi, td, tile, mi_row_pred, mi_col_pred, mi_row_ori,
                      mi_col_ori, bsize_pred);
-  xd->plane[0].dst.stride = dst_stride[0];
-  xd->plane[1].dst.stride = dst_stride[1];
-  xd->plane[2].dst.stride = dst_stride[2];
-  xd->plane[0].dst.buf = dst_buf[0] +
-                         (r >> xd->plane[0].subsampling_y) * dst_stride[0] +
-                         (c >> xd->plane[0].subsampling_x);
-  xd->plane[1].dst.buf = dst_buf[1] +
-                         (r >> xd->plane[1].subsampling_y) * dst_stride[1] +
-                         (c >> xd->plane[1].subsampling_x);
-  xd->plane[2].dst.buf = dst_buf[2] +
-                         (r >> xd->plane[2].subsampling_y) * dst_stride[2] +
-                         (c >> xd->plane[2].subsampling_x);
+  xd->plane[plane].dst.stride = dst_stride;
+  xd->plane[plane].dst.buf =
+      dst_buf + (r >> xd->plane[plane].subsampling_y) * dst_stride +
+      (c >> xd->plane[plane].subsampling_x);
 
   predict_superblock(cpi, td,
 #if CONFIG_EXT_INTER
                      mi_row_ori, mi_col_ori,
 #endif  // CONFIG_EXT_INTER
-                     mi_row_pred, mi_col_pred, bsize_pred, b_sub8x8, block);
+                     mi_row_pred, mi_col_pred, plane, bsize_pred, b_sub8x8,
+                     block);
 
-  if (!dry_run && !bextend)
+  if (!dry_run && (plane == 0) && (block == 0 || !b_sub8x8))
     update_stats(&cpi->common, td, mi_row_pred, mi_col_pred, 1);
 }
 
 static void extend_dir(const AV1_COMP *const cpi, ThreadData *td,
                        const TileInfo *const tile, int block, BLOCK_SIZE bsize,
-                       BLOCK_SIZE top_bsize, int mi_row, int mi_col,
-                       int mi_row_top, int mi_col_top, RUN_TYPE dry_run,
-                       uint8_t *dst_buf[3], int dst_stride[3], int dir) {
+                       BLOCK_SIZE top_bsize, int mi_row_ori, int mi_col_ori,
+                       int mi_row, int mi_col, int mi_row_top, int mi_col_top,
+                       int plane, uint8_t *dst_buf, int dst_stride, int dir) {
   // dir: 0-lower, 1-upper, 2-left, 3-right
   //      4-lowerleft, 5-upperleft, 6-lowerright, 7-upperright
   MACROBLOCKD *xd = &td->mb.e_mbd;
@@ -5973,10 +5967,10 @@
 
     for (j = 0; j < mi_height + ext_offset; j += high_unit)
       for (i = 0; i < mi_width + ext_offset; i += wide_unit)
-        predict_b_extend(cpi, td, tile, block, mi_row, mi_col, mi_row_pred + j,
-                         mi_col_pred + i, mi_row_top, mi_col_top, dst_buf,
-                         dst_stride, top_bsize, extend_bsize, dry_run, b_sub8x8,
-                         1);
+        predict_b_extend(cpi, td, tile, block, mi_row_ori, mi_col_ori,
+                         mi_row_pred + j, mi_col_pred + i, mi_row_top,
+                         mi_col_top, plane, dst_buf, dst_stride, top_bsize,
+                         extend_bsize, 1, b_sub8x8);
   } else if (dir == 2 || dir == 3) {  // left and right
     extend_bsize =
         (mi_height == mi_size_high[BLOCK_8X8] || bsize < BLOCK_8X8 || yss < xss)
@@ -5996,10 +5990,10 @@
 
     for (j = 0; j < mi_height + ext_offset; j += high_unit)
       for (i = 0; i < mi_width + ext_offset; i += wide_unit)
-        predict_b_extend(cpi, td, tile, block, mi_row, mi_col, mi_row_pred + j,
-                         mi_col_pred + i, mi_row_top, mi_col_top, dst_buf,
-                         dst_stride, top_bsize, extend_bsize, dry_run, b_sub8x8,
-                         1);
+        predict_b_extend(cpi, td, tile, block, mi_row_ori, mi_col_ori,
+                         mi_row_pred + j, mi_col_pred + i, mi_row_top,
+                         mi_col_top, plane, dst_buf, dst_stride, top_bsize,
+                         extend_bsize, 1, b_sub8x8);
   } else {
     extend_bsize = BLOCK_8X8;
 #if CONFIG_CB4X4
@@ -6018,35 +6012,24 @@
 
     for (j = 0; j < mi_height + ext_offset; j += high_unit)
       for (i = 0; i < mi_width + ext_offset; i += wide_unit)
-        predict_b_extend(cpi, td, tile, block, mi_row, mi_col, mi_row_pred + j,
-                         mi_col_pred + i, mi_row_top, mi_col_top, dst_buf,
-                         dst_stride, top_bsize, extend_bsize, dry_run, b_sub8x8,
-                         1);
+        predict_b_extend(cpi, td, tile, block, mi_row_ori, mi_col_ori,
+                         mi_row_pred + j, mi_col_pred + i, mi_row_top,
+                         mi_col_top, plane, dst_buf, dst_stride, top_bsize,
+                         extend_bsize, 1, b_sub8x8);
   }
 }
 
 static void extend_all(const AV1_COMP *const cpi, ThreadData *td,
                        const TileInfo *const tile, int block, BLOCK_SIZE bsize,
-                       BLOCK_SIZE top_bsize, int mi_row, int mi_col,
-                       int mi_row_top, int mi_col_top, RUN_TYPE dry_run,
-                       uint8_t *dst_buf[3], int dst_stride[3]) {
+                       BLOCK_SIZE top_bsize, int mi_row_ori, int mi_col_ori,
+                       int mi_row, int mi_col, int mi_row_top, int mi_col_top,
+                       int plane, uint8_t *dst_buf, int dst_stride) {
   assert(block >= 0 && block < 4);
-  extend_dir(cpi, td, tile, block, bsize, top_bsize, mi_row, mi_col, mi_row_top,
-             mi_col_top, dry_run, dst_buf, dst_stride, 0);
-  extend_dir(cpi, td, tile, block, bsize, top_bsize, mi_row, mi_col, mi_row_top,
-             mi_col_top, dry_run, dst_buf, dst_stride, 1);
-  extend_dir(cpi, td, tile, block, bsize, top_bsize, mi_row, mi_col, mi_row_top,
-             mi_col_top, dry_run, dst_buf, dst_stride, 2);
-  extend_dir(cpi, td, tile, block, bsize, top_bsize, mi_row, mi_col, mi_row_top,
-             mi_col_top, dry_run, dst_buf, dst_stride, 3);
-  extend_dir(cpi, td, tile, block, bsize, top_bsize, mi_row, mi_col, mi_row_top,
-             mi_col_top, dry_run, dst_buf, dst_stride, 4);
-  extend_dir(cpi, td, tile, block, bsize, top_bsize, mi_row, mi_col, mi_row_top,
-             mi_col_top, dry_run, dst_buf, dst_stride, 5);
-  extend_dir(cpi, td, tile, block, bsize, top_bsize, mi_row, mi_col, mi_row_top,
-             mi_col_top, dry_run, dst_buf, dst_stride, 6);
-  extend_dir(cpi, td, tile, block, bsize, top_bsize, mi_row, mi_col, mi_row_top,
-             mi_col_top, dry_run, dst_buf, dst_stride, 7);
+  for (int i = 0; i < 8; ++i) {
+    extend_dir(cpi, td, tile, block, bsize, top_bsize, mi_row_ori, mi_col_ori,
+               mi_row, mi_col, mi_row_top, mi_col_top, plane, dst_buf,
+               dst_stride, i);
+  }
 }
 
 // This function generates prediction for multiple blocks, between which
@@ -6140,29 +6123,36 @@
   switch (partition) {
     case PARTITION_NONE:
       assert(bsize < top_bsize);
-      predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                       mi_row_top, mi_col_top, dst_buf, dst_stride, top_bsize,
-                       bsize, dry_run, 0, 0);
-      extend_all(cpi, td, tile, 0, bsize, top_bsize, mi_row, mi_col, mi_row_top,
-                 mi_col_top, dry_run, dst_buf, dst_stride);
+      for (i = 0; i < MAX_MB_PLANE; ++i) {
+        predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                         mi_row_top, mi_col_top, i, dst_buf[i], dst_stride[i],
+                         top_bsize, bsize, dry_run, 0);
+        extend_all(cpi, td, tile, 0, bsize, top_bsize, mi_row, mi_col, mi_row,
+                   mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                   dst_stride[i]);
+      }
       break;
     case PARTITION_HORZ:
       if (bsize == BLOCK_8X8 && !unify_bsize) {
-        // Fisrt half
-        predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride, top_bsize,
-                         BLOCK_8X8, dry_run, 1, 0);
-        if (bsize < top_bsize)
-          extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf, dst_stride);
+        for (i = 0; i < MAX_MB_PLANE; ++i) {
+          // First half
+          predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                           mi_row_top, mi_col_top, i, dst_buf[i], dst_stride[i],
+                           top_bsize, BLOCK_8X8, dry_run, 1);
+          if (bsize < top_bsize)
+            extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                       mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                       dst_stride[i]);
 
-        // Second half
-        predict_b_extend(cpi, td, tile, 2, mi_row, mi_col, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf1, dst_stride1,
-                         top_bsize, BLOCK_8X8, dry_run, 1, 1);
-        if (bsize < top_bsize)
-          extend_all(cpi, td, tile, 2, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf1, dst_stride1);
+          // Second half
+          predict_b_extend(cpi, td, tile, 2, mi_row, mi_col, mi_row, mi_col,
+                           mi_row_top, mi_col_top, i, dst_buf1[i],
+                           dst_stride1[i], top_bsize, BLOCK_8X8, dry_run, 1);
+          if (bsize < top_bsize)
+            extend_all(cpi, td, tile, 2, subsize, top_bsize, mi_row, mi_col,
+                       mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf1[i],
+                       dst_stride1[i]);
+        }
 
         // Smooth
         xd->plane[0].dst.buf = dst_buf[0];
@@ -6172,60 +6162,89 @@
             mi_col, mi_row_top, mi_col_top, bsize, top_bsize, PARTITION_HORZ,
             0);
       } else {
-        // First half
-        predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride, top_bsize,
-                         subsize, dry_run, 0, 0);
-        if (bsize < top_bsize)
-          extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf, dst_stride);
-        else
-          extend_dir(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf, dst_stride, 0);
+        for (i = 0; i < MAX_MB_PLANE; ++i) {
+#if CONFIG_CB4X4
+          const struct macroblockd_plane *pd = &xd->plane[i];
+          int handle_chroma_sub8x8 = need_handle_chroma_sub8x8(
+              subsize, pd->subsampling_x, pd->subsampling_y);
 
-        if (mi_row + hbs < cm->mi_rows) {
-          // Second half
-          predict_b_extend(cpi, td, tile, 0, mi_row + hbs, mi_col, mi_row + hbs,
-                           mi_col, mi_row_top, mi_col_top, dst_buf1,
-                           dst_stride1, top_bsize, subsize, dry_run, 0, 0);
-          if (bsize < top_bsize)
-            extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row + hbs,
-                       mi_col, mi_row_top, mi_col_top, dry_run, dst_buf1,
-                       dst_stride1);
-          else
-            extend_dir(cpi, td, tile, 0, subsize, top_bsize, mi_row + hbs,
-                       mi_col, mi_row_top, mi_col_top, dry_run, dst_buf1,
-                       dst_stride1, 1);
+          if (handle_chroma_sub8x8) {
+            int mode_offset_row = CONFIG_CHROMA_SUB8X8 ? hbs : 0;
 
-          // Smooth
-          for (i = 0; i < MAX_MB_PLANE; i++) {
+            predict_b_extend(cpi, td, tile, 0, mi_row + mode_offset_row, mi_col,
+                             mi_row, mi_col, mi_row_top, mi_col_top, i,
+                             dst_buf[i], dst_stride[i], top_bsize, bsize,
+                             dry_run, 0);
+            if (bsize < top_bsize)
+              extend_all(cpi, td, tile, 0, bsize, top_bsize,
+                         mi_row + mode_offset_row, mi_col, mi_row, mi_col,
+                         mi_row_top, mi_col_top, i, dst_buf[i], dst_stride[i]);
+          } else {
+#endif
+            // First half
+            predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                             mi_row_top, mi_col_top, i, dst_buf[i],
+                             dst_stride[i], top_bsize, subsize, dry_run, 0);
+            if (bsize < top_bsize)
+              extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                         mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                         dst_stride[i]);
+            else
+              extend_dir(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                         mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                         dst_stride[i], 0);
             xd->plane[i].dst.buf = dst_buf[i];
             xd->plane[i].dst.stride = dst_stride[i];
-            av1_build_masked_inter_predictor_complex(
-                xd, dst_buf[i], dst_stride[i], dst_buf1[i], dst_stride1[i],
-                mi_row, mi_col, mi_row_top, mi_col_top, bsize, top_bsize,
-                PARTITION_HORZ, i);
+
+            if (mi_row + hbs < cm->mi_rows) {
+              // Second half
+              predict_b_extend(cpi, td, tile, 0, mi_row + hbs, mi_col,
+                               mi_row + hbs, mi_col, mi_row_top, mi_col_top, i,
+                               dst_buf1[i], dst_stride1[i], top_bsize, subsize,
+                               dry_run, 0);
+              if (bsize < top_bsize)
+                extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row + hbs,
+                           mi_col, mi_row + hbs, mi_col, mi_row_top, mi_col_top,
+                           i, dst_buf1[i], dst_stride1[i]);
+              else
+                extend_dir(cpi, td, tile, 0, subsize, top_bsize, mi_row + hbs,
+                           mi_col, mi_row + hbs, mi_col, mi_row_top, mi_col_top,
+                           i, dst_buf1[i], dst_stride1[i], 1);
+              // Smooth
+              xd->plane[i].dst.buf = dst_buf[i];
+              xd->plane[i].dst.stride = dst_stride[i];
+              av1_build_masked_inter_predictor_complex(
+                  xd, dst_buf[i], dst_stride[i], dst_buf1[i], dst_stride1[i],
+                  mi_row, mi_col, mi_row_top, mi_col_top, bsize, top_bsize,
+                  PARTITION_HORZ, i);
+            }
+#if CONFIG_CB4X4
           }
+#endif
         }
       }
       break;
     case PARTITION_VERT:
       if (bsize == BLOCK_8X8 && !unify_bsize) {
-        // First half
-        predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride, top_bsize,
-                         BLOCK_8X8, dry_run, 1, 0);
-        if (bsize < top_bsize)
-          extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf, dst_stride);
+        for (i = 0; i < MAX_MB_PLANE; ++i) {
+          // First half
+          predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                           mi_row_top, mi_col_top, i, dst_buf[i], dst_stride[i],
+                           top_bsize, BLOCK_8X8, dry_run, 1);
+          if (bsize < top_bsize)
+            extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                       mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                       dst_stride[i]);
 
-        // Second half
-        predict_b_extend(cpi, td, tile, 1, mi_row, mi_col, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf1, dst_stride1,
-                         top_bsize, BLOCK_8X8, dry_run, 1, 1);
-        if (bsize < top_bsize)
-          extend_all(cpi, td, tile, 1, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf1, dst_stride1);
+          // Second half
+          predict_b_extend(cpi, td, tile, 1, mi_row, mi_col, mi_row, mi_col,
+                           mi_row_top, mi_col_top, i, dst_buf1[i],
+                           dst_stride1[i], top_bsize, BLOCK_8X8, dry_run, 1);
+          if (bsize < top_bsize)
+            extend_all(cpi, td, tile, 1, subsize, top_bsize, mi_row, mi_col,
+                       mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf1[i],
+                       dst_stride1[i]);
+        }
 
         // Smooth
         xd->plane[0].dst.buf = dst_buf[0];
@@ -6235,66 +6254,160 @@
             mi_col, mi_row_top, mi_col_top, bsize, top_bsize, PARTITION_VERT,
             0);
       } else {
-        // bsize: not important, not useful
-        predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride, top_bsize,
-                         subsize, dry_run, 0, 0);
-        if (bsize < top_bsize)
-          extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf, dst_stride);
-        else
-          extend_dir(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf, dst_stride, 3);
+        for (i = 0; i < MAX_MB_PLANE; ++i) {
+#if CONFIG_CB4X4
+          const struct macroblockd_plane *pd = &xd->plane[i];
+          int handle_chroma_sub8x8 = need_handle_chroma_sub8x8(
+              subsize, pd->subsampling_x, pd->subsampling_y);
 
-        if (mi_col + hbs < cm->mi_cols) {
-          predict_b_extend(cpi, td, tile, 0, mi_row, mi_col + hbs, mi_row,
-                           mi_col + hbs, mi_row_top, mi_col_top, dst_buf1,
-                           dst_stride1, top_bsize, subsize, dry_run, 0, 0);
-          if (bsize < top_bsize)
-            extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row,
-                       mi_col + hbs, mi_row_top, mi_col_top, dry_run, dst_buf1,
-                       dst_stride1);
-          else
-            extend_dir(cpi, td, tile, 0, subsize, top_bsize, mi_row,
-                       mi_col + hbs, mi_row_top, mi_col_top, dry_run, dst_buf1,
-                       dst_stride1, 2);
+          if (handle_chroma_sub8x8) {
+            int mode_offset_col = CONFIG_CHROMA_SUB8X8 ? hbs : 0;
 
-          for (i = 0; i < MAX_MB_PLANE; i++) {
+            predict_b_extend(cpi, td, tile, 0, mi_row, mi_col + mode_offset_col,
+                             mi_row, mi_col, mi_row_top, mi_col_top, i,
+                             dst_buf[i], dst_stride[i], top_bsize, bsize,
+                             dry_run, 0);
+            if (bsize < top_bsize)
+              extend_all(cpi, td, tile, 0, bsize, top_bsize, mi_row,
+                         mi_col + mode_offset_col, mi_row, mi_col, mi_row_top,
+                         mi_col_top, i, dst_buf[i], dst_stride[i]);
+          } else {
+#endif
+            predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                             mi_row_top, mi_col_top, i, dst_buf[i],
+                             dst_stride[i], top_bsize, subsize, dry_run, 0);
+            if (bsize < top_bsize)
+              extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                         mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                         dst_stride[i]);
+            else
+              extend_dir(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                         mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                         dst_stride[i], 3);
             xd->plane[i].dst.buf = dst_buf[i];
             xd->plane[i].dst.stride = dst_stride[i];
-            av1_build_masked_inter_predictor_complex(
-                xd, dst_buf[i], dst_stride[i], dst_buf1[i], dst_stride1[i],
-                mi_row, mi_col, mi_row_top, mi_col_top, bsize, top_bsize,
-                PARTITION_VERT, i);
+
+            if (mi_col + hbs < cm->mi_cols) {
+              predict_b_extend(cpi, td, tile, 0, mi_row, mi_col + hbs, mi_row,
+                               mi_col + hbs, mi_row_top, mi_col_top, i,
+                               dst_buf1[i], dst_stride1[i], top_bsize, subsize,
+                               dry_run, 0);
+              if (bsize < top_bsize)
+                extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row,
+                           mi_col + hbs, mi_row, mi_col + hbs, mi_row_top,
+                           mi_col_top, i, dst_buf1[i], dst_stride1[i]);
+              else
+                extend_dir(cpi, td, tile, 0, subsize, top_bsize, mi_row,
+                           mi_col + hbs, mi_row, mi_col + hbs, mi_row_top,
+                           mi_col_top, i, dst_buf1[i], dst_stride1[i], 2);
+
+              // smooth
+              xd->plane[i].dst.buf = dst_buf[i];
+              xd->plane[i].dst.stride = dst_stride[i];
+              av1_build_masked_inter_predictor_complex(
+                  xd, dst_buf[i], dst_stride[i], dst_buf1[i], dst_stride1[i],
+                  mi_row, mi_col, mi_row_top, mi_col_top, bsize, top_bsize,
+                  PARTITION_VERT, i);
+            }
+#if CONFIG_CB4X4
           }
+#endif
         }
       }
       break;
     case PARTITION_SPLIT:
       if (bsize == BLOCK_8X8 && !unify_bsize) {
-        predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf, dst_stride, top_bsize,
-                         BLOCK_8X8, dry_run, 1, 0);
-        predict_b_extend(cpi, td, tile, 1, mi_row, mi_col, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf1, dst_stride1,
-                         top_bsize, BLOCK_8X8, dry_run, 1, 1);
-        predict_b_extend(cpi, td, tile, 2, mi_row, mi_col, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf2, dst_stride2,
-                         top_bsize, BLOCK_8X8, dry_run, 1, 1);
-        predict_b_extend(cpi, td, tile, 3, mi_row, mi_col, mi_row, mi_col,
-                         mi_row_top, mi_col_top, dst_buf3, dst_stride3,
-                         top_bsize, BLOCK_8X8, dry_run, 1, 1);
+        for (i = 0; i < MAX_MB_PLANE; i++) {
+          predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                           mi_row_top, mi_col_top, i, dst_buf[i], dst_stride[i],
+                           top_bsize, BLOCK_8X8, dry_run, 1);
+          predict_b_extend(cpi, td, tile, 1, mi_row, mi_col, mi_row, mi_col,
+                           mi_row_top, mi_col_top, i, dst_buf1[i],
+                           dst_stride1[i], top_bsize, BLOCK_8X8, dry_run, 1);
+          predict_b_extend(cpi, td, tile, 2, mi_row, mi_col, mi_row, mi_col,
+                           mi_row_top, mi_col_top, i, dst_buf2[i],
+                           dst_stride2[i], top_bsize, BLOCK_8X8, dry_run, 1);
+          predict_b_extend(cpi, td, tile, 3, mi_row, mi_col, mi_row, mi_col,
+                           mi_row_top, mi_col_top, i, dst_buf3[i],
+                           dst_stride3[i], top_bsize, BLOCK_8X8, dry_run, 1);
 
-        if (bsize < top_bsize) {
-          extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf, dst_stride);
-          extend_all(cpi, td, tile, 1, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf1, dst_stride1);
-          extend_all(cpi, td, tile, 2, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf2, dst_stride2);
-          extend_all(cpi, td, tile, 3, subsize, top_bsize, mi_row, mi_col,
-                     mi_row_top, mi_col_top, dry_run, dst_buf3, dst_stride3);
+          if (bsize < top_bsize) {
+            extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                       mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                       dst_stride[i]);
+            extend_all(cpi, td, tile, 1, subsize, top_bsize, mi_row, mi_col,
+                       mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf1[i],
+                       dst_stride1[i]);
+            extend_all(cpi, td, tile, 2, subsize, top_bsize, mi_row, mi_col,
+                       mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf2[i],
+                       dst_stride2[i]);
+            extend_all(cpi, td, tile, 3, subsize, top_bsize, mi_row, mi_col,
+                       mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf3[i],
+                       dst_stride3[i]);
+          }
         }
+#if CONFIG_CB4X4
+      } else if (bsize == BLOCK_8X8) {
+        for (i = 0; i < MAX_MB_PLANE; i++) {
+          const struct macroblockd_plane *pd = &xd->plane[i];
+          int handle_chroma_sub8x8 = need_handle_chroma_sub8x8(
+              subsize, pd->subsampling_x, pd->subsampling_y);
+
+          if (handle_chroma_sub8x8) {
+            int mode_offset_row =
+                CONFIG_CHROMA_SUB8X8 && mi_row + hbs < cm->mi_rows ? hbs : 0;
+            int mode_offset_col =
+                CONFIG_CHROMA_SUB8X8 && mi_col + hbs < cm->mi_cols ? hbs : 0;
+
+            predict_b_extend(cpi, td, tile, 0, mi_row + mode_offset_row,
+                             mi_col + mode_offset_col, mi_row, mi_col,
+                             mi_row_top, mi_col_top, i, dst_buf[i],
+                             dst_stride[i], top_bsize, BLOCK_8X8, dry_run, 0);
+            if (bsize < top_bsize)
+              extend_all(cpi, td, tile, 0, BLOCK_8X8, top_bsize,
+                         mi_row + mode_offset_row, mi_col + mode_offset_col,
+                         mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                         dst_stride[i]);
+          } else {
+            predict_b_extend(cpi, td, tile, 0, mi_row, mi_col, mi_row, mi_col,
+                             mi_row_top, mi_col_top, i, dst_buf[i],
+                             dst_stride[i], top_bsize, subsize, dry_run, 0);
+            if (mi_row < cm->mi_rows && mi_col + hbs < cm->mi_cols)
+              predict_b_extend(cpi, td, tile, 0, mi_row, mi_col + hbs, mi_row,
+                               mi_col + hbs, mi_row_top, mi_col_top, i,
+                               dst_buf1[i], dst_stride1[i], top_bsize, subsize,
+                               dry_run, 0);
+            if (mi_row + hbs < cm->mi_rows && mi_col < cm->mi_cols)
+              predict_b_extend(cpi, td, tile, 0, mi_row + hbs, mi_col,
+                               mi_row + hbs, mi_col, mi_row_top, mi_col_top, i,
+                               dst_buf2[i], dst_stride2[i], top_bsize, subsize,
+                               dry_run, 0);
+            if (mi_row + hbs < cm->mi_rows && mi_col + hbs < cm->mi_cols)
+              predict_b_extend(cpi, td, tile, 0, mi_row + hbs, mi_col + hbs,
+                               mi_row + hbs, mi_col + hbs, mi_row_top,
+                               mi_col_top, i, dst_buf3[i], dst_stride3[i],
+                               top_bsize, subsize, dry_run, 0);
+
+            if (bsize < top_bsize) {
+              extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row, mi_col,
+                         mi_row, mi_col, mi_row_top, mi_col_top, i, dst_buf[i],
+                         dst_stride[i]);
+              if (mi_row < cm->mi_rows && mi_col + hbs < cm->mi_cols)
+                extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row,
+                           mi_col + hbs, mi_row, mi_col + hbs, mi_row_top,
+                           mi_col_top, i, dst_buf1[i], dst_stride1[i]);
+              if (mi_row + hbs < cm->mi_rows && mi_col < cm->mi_cols)
+                extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row + hbs,
+                           mi_col, mi_row + hbs, mi_col, mi_row_top, mi_col_top,
+                           i, dst_buf2[i], dst_stride2[i]);
+              if (mi_row + hbs < cm->mi_rows && mi_col + hbs < cm->mi_cols)
+                extend_all(cpi, td, tile, 0, subsize, top_bsize, mi_row + hbs,
+                           mi_col + hbs, mi_row + hbs, mi_col + hbs, mi_row_top,
+                           mi_col_top, i, dst_buf3[i], dst_stride3[i]);
+            }
+          }
+        }
+#endif
       } else {
         predict_sb_complex(cpi, td, tile, mi_row, mi_col, mi_row_top,
                            mi_col_top, dry_run, subsize, top_bsize, dst_buf,
@@ -6314,10 +6427,16 @@
                              pc_tree->split[3]);
       }
       for (i = 0; i < MAX_MB_PLANE; i++) {
-#if !CONFIG_CB4X4
+#if CONFIG_CB4X4
+        const struct macroblockd_plane *pd = &xd->plane[i];
+        int handle_chroma_sub8x8 = need_handle_chroma_sub8x8(
+            subsize, pd->subsampling_x, pd->subsampling_y);
+        if (handle_chroma_sub8x8) continue;  // Skip <4x4 chroma smoothing
+#else
         if (bsize == BLOCK_8X8 && i != 0)
           continue;  // Skip <4x4 chroma smoothing
 #endif
+
         if (mi_row < cm->mi_rows && mi_col + hbs < cm->mi_cols) {
           av1_build_masked_inter_predictor_complex(
               xd, dst_buf[i], dst_stride[i], dst_buf1[i], dst_stride1[i],
@@ -6334,9 +6453,6 @@
                 PARTITION_HORZ, i);
           }
         } else if (mi_row + hbs < cm->mi_rows && mi_col < cm->mi_cols) {
-          if (bsize == BLOCK_8X8 && i != 0)
-            continue;  // Skip <4x4 chroma smoothing
-
           av1_build_masked_inter_predictor_complex(
               xd, dst_buf[i], dst_stride[i], dst_buf2[i], dst_stride2[i],
               mi_row, mi_col, mi_row_top, mi_col_top, bsize, top_bsize,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 32ed8df..0c21cac 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1455,7 +1455,9 @@
   int64_t rd1, rd2, rd;
   RD_STATS this_rd_stats;
 
+#if !CONFIG_SUPERTX
   assert(tx_size == get_tx_size(plane, xd));
+#endif  // !CONFIG_SUPERTX
 
   av1_init_rd_stats(&this_rd_stats);