Explicitly set chroma ref info based on tree_type

Currently, CHROMA_REF_INFO::is_chroma_ref does not differentiate based
on the partition tree type - it is set based on the assumption on the
current partition type is SHARED_PART. This commit explicitly set
is_chroma_ref to 0 on LUMA_TREE, and 1 to CHROMA_TREE.

This change does not affect the bitstream syntax, as the chroma
predictors check both is_chroma_ref and tree_type. However, this affects
some encoder-side optimizations (more specifically a partition tree
reuse feature), which causes a small difference in encoding output.

| CONFIG | PSNR_YUV | Enc Time | Dec Time |
| ------ | -------- | -------- | -------- |
|   AI   |  +0.02%  |    97%   |   100%   |
|   RA   |  +0.00%  |    99%   |    99%   |

STATS_CHANGED
diff --git a/av1/common/av1_common_int.h b/av1/common/av1_common_int.h
index db713a0..1f41342 100644
--- a/av1/common/av1_common_int.h
+++ b/av1/common/av1_common_int.h
@@ -2757,8 +2757,9 @@
   if (subsize < BLOCK_SIZES_ALL) {
     CHROMA_REF_INFO tmp_chroma_ref_info = { 1,      0,       mi_row,
                                             mi_col, subsize, subsize };
-    set_chroma_ref_info(mi_row, mi_col, 0, subsize, &tmp_chroma_ref_info,
-                        parent_chroma_ref_info, bsize, partition, ss_x, ss_y);
+    set_chroma_ref_info(tree_type, mi_row, mi_col, 0, subsize,
+                        &tmp_chroma_ref_info, parent_chroma_ref_info, bsize,
+                        partition, ss_x, ss_y);
     is_valid = get_plane_block_size(tmp_chroma_ref_info.bsize_base, ss_x,
                                     ss_y) != BLOCK_INVALID;
   }
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index f68a343..193e7d2 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -1378,14 +1378,23 @@
   }
 }
 
-static INLINE void set_chroma_ref_info(int mi_row, int mi_col, int index,
-                                       BLOCK_SIZE bsize, CHROMA_REF_INFO *info,
+static INLINE void set_chroma_ref_info(TREE_TYPE tree_type, int mi_row,
+                                       int mi_col, int index, BLOCK_SIZE bsize,
+                                       CHROMA_REF_INFO *info,
                                        const CHROMA_REF_INFO *parent_info,
                                        BLOCK_SIZE parent_bsize,
                                        PARTITION_TYPE parent_partition,
                                        int ss_x, int ss_y) {
   assert(bsize < BLOCK_SIZES_ALL);
   initialize_chroma_ref_info(mi_row, mi_col, bsize, info);
+  if (tree_type == LUMA_PART) {
+    info->is_chroma_ref = 0;
+    return;
+  }
+  if (tree_type == CHROMA_PART) {
+    info->is_chroma_ref = 1;
+    return;
+  }
   if (parent_info == NULL) return;
   if (parent_info->is_chroma_ref) {
     if (parent_info->offset_started) {
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 708a020..47666bb 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -602,8 +602,8 @@
   }
 
   CHROMA_REF_INFO *chroma_ref_info = &xd->mi[0]->chroma_ref_info;
-  set_chroma_ref_info(mi_row, mi_col, index, bsize, chroma_ref_info,
-                      parent ? &parent->chroma_ref_info : NULL,
+  set_chroma_ref_info(xd->tree_type, mi_row, mi_col, index, bsize,
+                      chroma_ref_info, parent ? &parent->chroma_ref_info : NULL,
                       parent ? parent->bsize : BLOCK_INVALID,
                       parent ? parent->partition : PARTITION_NONE,
                       xd->plane[1].subsampling_x, xd->plane[1].subsampling_y);
@@ -2216,8 +2216,8 @@
 #endif  // CONFIG_CROSS_CHROMA_TX
 
   CHROMA_REF_INFO *chroma_ref_info = &xd->mi[0]->chroma_ref_info;
-  set_chroma_ref_info(mi_row, mi_col, index, bsize, chroma_ref_info,
-                      parent ? &parent->chroma_ref_info : NULL,
+  set_chroma_ref_info(xd->tree_type, mi_row, mi_col, index, bsize,
+                      chroma_ref_info, parent ? &parent->chroma_ref_info : NULL,
                       parent ? parent->bsize : BLOCK_INVALID,
                       parent ? parent->partition : PARTITION_NONE,
                       xd->plane[1].subsampling_x, xd->plane[1].subsampling_y);
@@ -2505,8 +2505,8 @@
     ptree->is_settled = 1;
     PARTITION_TREE *parent = ptree->parent;
     set_chroma_ref_info(
-        mi_row, mi_col, ptree->index, bsize, &ptree->chroma_ref_info,
-        parent ? &parent->chroma_ref_info : NULL,
+        xd->tree_type, mi_row, mi_col, ptree->index, bsize,
+        &ptree->chroma_ref_info, parent ? &parent->chroma_ref_info : NULL,
         parent ? parent->bsize : BLOCK_INVALID,
         parent ? parent->partition : PARTITION_NONE, ss_x, ss_y);
 
@@ -2573,7 +2573,8 @@
     const int index =
         (partition == PARTITION_HORZ || partition == PARTITION_VERT) +
         (partition == PARTITION_HORZ_3 || partition == PARTITION_VERT_3);
-    set_chroma_ref_info(mi_row, mi_col, index, bsize, &chroma_ref_info,
+    set_chroma_ref_info(xd->tree_type, mi_row, mi_col, index, bsize,
+                        &chroma_ref_info,
                         parent ? &parent->chroma_ref_info : NULL,
                         parent ? parent->bsize : BLOCK_INVALID,
                         parent ? parent->partition : PARTITION_NONE,
diff --git a/av1/encoder/context_tree.c b/av1/encoder/context_tree.c
index 4765f17..9d70358 100644
--- a/av1/encoder/context_tree.c
+++ b/av1/encoder/context_tree.c
@@ -82,8 +82,9 @@
   }
 }
 
-PICK_MODE_CONTEXT *av1_alloc_pmc(const AV1_COMMON *cm, int mi_row, int mi_col,
-                                 BLOCK_SIZE bsize, PC_TREE *parent,
+PICK_MODE_CONTEXT *av1_alloc_pmc(const AV1_COMMON *cm, TREE_TYPE tree_type,
+                                 int mi_row, int mi_col, BLOCK_SIZE bsize,
+                                 PC_TREE *parent,
                                  PARTITION_TYPE parent_partition, int index,
                                  int subsampling_x, int subsampling_y,
                                  PC_TREE_SHARED_BUFFERS *shared_bufs) {
@@ -94,7 +95,8 @@
   ctx->rd_mode_is_ready = 0;
   ctx->parent = parent;
   ctx->index = index;
-  set_chroma_ref_info(mi_row, mi_col, index, bsize, &ctx->chroma_ref_info,
+  set_chroma_ref_info(tree_type, mi_row, mi_col, index, bsize,
+                      &ctx->chroma_ref_info,
                       parent ? &parent->chroma_ref_info : NULL,
                       parent ? parent->block_size : BLOCK_INVALID,
                       parent_partition, subsampling_x, subsampling_y);
@@ -186,8 +188,8 @@
   aom_free(ctx);
 }
 
-PC_TREE *av1_alloc_pc_tree_node(int mi_row, int mi_col, BLOCK_SIZE bsize,
-                                PC_TREE *parent,
+PC_TREE *av1_alloc_pc_tree_node(TREE_TYPE tree_type, int mi_row, int mi_col,
+                                BLOCK_SIZE bsize, PC_TREE *parent,
                                 PARTITION_TYPE parent_partition, int index,
                                 int is_last, int subsampling_x,
                                 int subsampling_y) {
@@ -208,7 +210,8 @@
   av1_invalid_rd_stats(&pc_tree->none_rd);
   pc_tree->skippable = false;
 #endif  // CONFIG_EXT_RECUR_PARTITIONS
-  set_chroma_ref_info(mi_row, mi_col, index, bsize, &pc_tree->chroma_ref_info,
+  set_chroma_ref_info(tree_type, mi_row, mi_col, index, bsize,
+                      &pc_tree->chroma_ref_info,
                       parent ? &parent->chroma_ref_info : NULL,
                       parent ? parent->block_size : BLOCK_INVALID,
                       parent_partition, subsampling_x, subsampling_y);
@@ -395,7 +398,7 @@
 void av1_copy_pc_tree_recursive(const AV1_COMMON *cm, PC_TREE *dst,
                                 PC_TREE *src, int ss_x, int ss_y,
                                 PC_TREE_SHARED_BUFFERS *shared_bufs,
-                                int num_planes) {
+                                TREE_TYPE tree_type, int num_planes) {
   // Copy the best partition type. For basic information like bsize and index,
   // we assume they have been set properly when initializing the dst PC_TREE
   dst->partitioning = src->partitioning;
@@ -420,7 +423,7 @@
       if (dst->none) av1_free_pmc(dst->none, num_planes);
       dst->none = NULL;
       if (src->none) {
-        dst->none = av1_alloc_pmc(cm, mi_row, mi_col, bsize, dst,
+        dst->none = av1_alloc_pmc(cm, tree_type, mi_row, mi_col, bsize, dst,
                                   PARTITION_NONE, 0, ss_x, ss_y, shared_bufs);
         av1_copy_tree_context(dst->none, src->none);
       }
@@ -437,10 +440,11 @@
             const int x_idx = (i & 1) * (mi_size_wide[bsize] >> 1);
             const int y_idx = (i >> 1) * (mi_size_high[bsize] >> 1);
             dst->split[i] = av1_alloc_pc_tree_node(
-                mi_row + y_idx, mi_col + x_idx, subsize, dst, PARTITION_SPLIT,
-                i, i == 3, ss_x, ss_y);
+                tree_type, mi_row + y_idx, mi_col + x_idx, subsize, dst,
+                PARTITION_SPLIT, i, i == 3, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->split[i], src->split[i], ss_x,
-                                       ss_y, shared_bufs, num_planes);
+                                       ss_y, shared_bufs, tree_type,
+                                       num_planes);
           }
         }
       }
@@ -455,12 +459,12 @@
           }
           if (src->horizontal[i]) {
             const int this_mi_row = mi_row + i * (mi_size_high[bsize] >> 1);
-            dst->horizontal[i] =
-                av1_alloc_pc_tree_node(this_mi_row, mi_col, subsize, dst,
-                                       PARTITION_HORZ, i, i == 1, ss_x, ss_y);
+            dst->horizontal[i] = av1_alloc_pc_tree_node(
+                tree_type, this_mi_row, mi_col, subsize, dst, PARTITION_HORZ, i,
+                i == 1, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->horizontal[i],
                                        src->horizontal[i], ss_x, ss_y,
-                                       shared_bufs, num_planes);
+                                       shared_bufs, tree_type, num_planes);
           }
         }
       }
@@ -475,11 +479,12 @@
           }
           if (src->vertical[i]) {
             const int this_mi_col = mi_col + i * (mi_size_wide[bsize] >> 1);
-            dst->vertical[i] =
-                av1_alloc_pc_tree_node(mi_row, this_mi_col, subsize, dst,
-                                       PARTITION_VERT, i, i == 1, ss_x, ss_y);
+            dst->vertical[i] = av1_alloc_pc_tree_node(
+                tree_type, mi_row, this_mi_col, subsize, dst, PARTITION_VERT, i,
+                i == 1, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->vertical[i], src->vertical[i],
-                                       ss_x, ss_y, shared_bufs, num_planes);
+                                       ss_x, ss_y, shared_bufs, tree_type,
+                                       num_planes);
           }
         }
       }
@@ -505,11 +510,11 @@
           }
           if (src->horizontal4a[i]) {
             dst->horizontal4a[i] = av1_alloc_pc_tree_node(
-                mi_rows[i], mi_col, subsizes[i], dst, PARTITION_HORZ_4A, i,
-                i == 3, ss_x, ss_y);
+                tree_type, mi_rows[i], mi_col, subsizes[i], dst,
+                PARTITION_HORZ_4A, i, i == 3, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->horizontal4a[i],
                                        src->horizontal4a[i], ss_x, ss_y,
-                                       shared_bufs, num_planes);
+                                       shared_bufs, tree_type, num_planes);
           }
         }
       }
@@ -534,11 +539,11 @@
           }
           if (src->horizontal4b[i]) {
             dst->horizontal4b[i] = av1_alloc_pc_tree_node(
-                mi_rows[i], mi_col, subsizes[i], dst, PARTITION_HORZ_4B, i,
-                i == 3, ss_x, ss_y);
+                tree_type, mi_rows[i], mi_col, subsizes[i], dst,
+                PARTITION_HORZ_4B, i, i == 3, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->horizontal4b[i],
                                        src->horizontal4b[i], ss_x, ss_y,
-                                       shared_bufs, num_planes);
+                                       shared_bufs, tree_type, num_planes);
           }
         }
       }
@@ -563,11 +568,11 @@
           }
           if (src->vertical4a[i]) {
             dst->vertical4a[i] = av1_alloc_pc_tree_node(
-                mi_row, mi_cols[i], subsizes[i], dst, PARTITION_VERT_4A, i,
-                i == 3, ss_x, ss_y);
+                tree_type, mi_row, mi_cols[i], subsizes[i], dst,
+                PARTITION_VERT_4A, i, i == 3, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->vertical4a[i],
                                        src->vertical4a[i], ss_x, ss_y,
-                                       shared_bufs, num_planes);
+                                       shared_bufs, tree_type, num_planes);
           }
         }
       }
@@ -592,11 +597,11 @@
           }
           if (src->vertical4b[i]) {
             dst->vertical4b[i] = av1_alloc_pc_tree_node(
-                mi_row, mi_cols[i], subsizes[i], dst, PARTITION_VERT_4B, i,
-                i == 3, ss_x, ss_y);
+                tree_type, mi_row, mi_cols[i], subsizes[i], dst,
+                PARTITION_VERT_4B, i, i == 3, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->vertical4b[i],
                                        src->vertical4b[i], ss_x, ss_y,
-                                       shared_bufs, num_planes);
+                                       shared_bufs, tree_type, num_planes);
           }
         }
       }
@@ -621,11 +626,11 @@
           }
           if (src->horizontal3[i]) {
             dst->horizontal3[i] = av1_alloc_pc_tree_node(
-                mi_row + offset_mr, mi_col + offset_mc, this_subsize, dst,
-                PARTITION_HORZ_3, i, i == 3, ss_x, ss_y);
+                tree_type, mi_row + offset_mr, mi_col + offset_mc, this_subsize,
+                dst, PARTITION_HORZ_3, i, i == 3, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->horizontal3[i],
                                        src->horizontal3[i], ss_x, ss_y,
-                                       shared_bufs, num_planes);
+                                       shared_bufs, tree_type, num_planes);
           }
         }
       }
@@ -647,10 +652,11 @@
           }
           if (src->vertical3[i]) {
             dst->vertical3[i] = av1_alloc_pc_tree_node(
-                mi_row + offset_mr, mi_col + offset_mc, this_subsize, dst,
-                PARTITION_VERT_3, i, i == 3, ss_x, ss_y);
+                tree_type, mi_row + offset_mr, mi_col + offset_mc, this_subsize,
+                dst, PARTITION_VERT_3, i, i == 3, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->vertical3[i], src->vertical3[i],
-                                       ss_x, ss_y, shared_bufs, num_planes);
+                                       ss_x, ss_y, shared_bufs, tree_type,
+                                       num_planes);
           }
         }
       }
@@ -673,12 +679,12 @@
             dst->horizontal3[i] = NULL;
           }
           if (src->horizontal3[i]) {
-            dst->horizontal3[i] =
-                av1_alloc_pc_tree_node(mi_rows[i], mi_col, subsizes[i], dst,
-                                       PARTITION_HORZ_3, i, i == 2, ss_x, ss_y);
+            dst->horizontal3[i] = av1_alloc_pc_tree_node(
+                tree_type, mi_rows[i], mi_col, subsizes[i], dst,
+                PARTITION_HORZ_3, i, i == 2, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->horizontal3[i],
                                        src->horizontal3[i], ss_x, ss_y,
-                                       shared_bufs, num_planes);
+                                       shared_bufs, tree_type, num_planes);
           }
         }
       }
@@ -698,11 +704,12 @@
             dst->vertical3[i] = NULL;
           }
           if (src->vertical3[i]) {
-            dst->vertical3[i] =
-                av1_alloc_pc_tree_node(mi_row, mi_cols[i], subsizes[i], dst,
-                                       PARTITION_VERT_3, i, i == 2, ss_x, ss_y);
+            dst->vertical3[i] = av1_alloc_pc_tree_node(
+                tree_type, mi_row, mi_cols[i], subsizes[i], dst,
+                PARTITION_VERT_3, i, i == 2, ss_x, ss_y);
             av1_copy_pc_tree_recursive(cm, dst->vertical3[i], src->vertical3[i],
-                                       ss_x, ss_y, shared_bufs, num_planes);
+                                       ss_x, ss_y, shared_bufs, tree_type,
+                                       num_planes);
           }
         }
       }
diff --git a/av1/encoder/context_tree.h b/av1/encoder/context_tree.h
index 4631f84..0def611 100644
--- a/av1/encoder/context_tree.h
+++ b/av1/encoder/context_tree.h
@@ -145,8 +145,8 @@
                                    PC_TREE_SHARED_BUFFERS *shared_bufs);
 void av1_free_shared_coeff_buffer(PC_TREE_SHARED_BUFFERS *shared_bufs);
 
-PC_TREE *av1_alloc_pc_tree_node(int mi_row, int mi_col, BLOCK_SIZE bsize,
-                                PC_TREE *parent,
+PC_TREE *av1_alloc_pc_tree_node(TREE_TYPE tree_type, int mi_row, int mi_col,
+                                BLOCK_SIZE bsize, PC_TREE *parent,
                                 PARTITION_TYPE parent_partition, int index,
                                 int is_last, int subsampling_x,
                                 int subsampling_y);
@@ -156,11 +156,12 @@
 void av1_copy_pc_tree_recursive(const AV1_COMMON *cm, PC_TREE *dst,
                                 PC_TREE *src, int ss_x, int ss_y,
                                 PC_TREE_SHARED_BUFFERS *shared_bufs,
-                                int num_planes);
+                                TREE_TYPE tree_type, int num_planes);
 #endif  // CONFIG_EXT_RECUR_PARTITIONS
 
-PICK_MODE_CONTEXT *av1_alloc_pmc(const AV1_COMMON *cm, int mi_row, int mi_col,
-                                 BLOCK_SIZE bsize, PC_TREE *parent,
+PICK_MODE_CONTEXT *av1_alloc_pmc(const AV1_COMMON *cm, TREE_TYPE tree_type,
+                                 int mi_row, int mi_col, BLOCK_SIZE bsize,
+                                 PC_TREE *parent,
                                  PARTITION_TYPE parent_partition, int index,
                                  int subsampling_x, int subsampling_y,
                                  PC_TREE_SHARED_BUFFERS *shared_bufs);
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index af23cd2..ca7d7bc 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -573,8 +573,9 @@
                              : (loop_idx == 0 ? LUMA_PART : CHROMA_PART));
     init_encode_rd_sb(cpi, td, tile_data, sms_root, &dummy_rdc, mi_row, mi_col,
                       1);
-    PC_TREE *const pc_root = av1_alloc_pc_tree_node(
-        mi_row, mi_col, sb_size, NULL, PARTITION_NONE, 0, 1, ss_x, ss_y);
+    PC_TREE *const pc_root =
+        av1_alloc_pc_tree_node(xd->tree_type, mi_row, mi_col, sb_size, NULL,
+                               PARTITION_NONE, 0, 1, ss_x, ss_y);
 #if CONFIG_EXT_RECUR_PARTITIONS
     const PARTITION_TREE *template_tree =
         multi_pass_params ? multi_pass_params->template_tree : NULL;
@@ -804,8 +805,9 @@
           cm, xd->tree_type, mi_row, mi_col, bsize,
           xd->sbi->ptree_root[av1_get_sdp_idx(xd->tree_type)]);
 #endif  // CONFIG_EXT_RECUR_PARTITIONS
-      PC_TREE *const pc_root = av1_alloc_pc_tree_node(
-          mi_row, mi_col, sb_size, NULL, PARTITION_NONE, 0, 1, ss_x, ss_y);
+      PC_TREE *const pc_root =
+          av1_alloc_pc_tree_node(xd->tree_type, mi_row, mi_col, sb_size, NULL,
+                                 PARTITION_NONE, 0, 1, ss_x, ss_y);
       av1_rd_use_partition(cpi, td, tile_data, mi, tp, mi_row, mi_col, sb_size,
                            &dummy_rate, &dummy_dist, 1,
 #if CONFIG_EXT_RECUR_PARTITIONS
@@ -832,8 +834,9 @@
                                : (loop_idx == 0 ? LUMA_PART : CHROMA_PART));
       init_encode_rd_sb(cpi, td, tile_data, sms_root, &dummy_rdc, mi_row,
                         mi_col, 1);
-      PC_TREE *const pc_root = av1_alloc_pc_tree_node(
-          mi_row, mi_col, sb_size, NULL, PARTITION_NONE, 0, 1, ss_x, ss_y);
+      PC_TREE *const pc_root =
+          av1_alloc_pc_tree_node(xd->tree_type, mi_row, mi_col, sb_size, NULL,
+                                 PARTITION_NONE, 0, 1, ss_x, ss_y);
 #if CONFIG_EXT_RECUR_PARTITIONS
       av1_reset_ptree_in_sbi(xd->sbi, xd->tree_type);
       av1_build_partition_tree_fixed_partitioning(
diff --git a/av1/encoder/encoder_alloc.h b/av1/encoder/encoder_alloc.h
index d167ce3..d40a3a8 100644
--- a/av1/encoder/encoder_alloc.h
+++ b/av1/encoder/encoder_alloc.h
@@ -80,7 +80,7 @@
 #endif  // CONFIG_EXT_RECUR_PARTITIONS
 
   cpi->td.firstpass_ctx =
-      av1_alloc_pmc(cm, 0, 0, BLOCK_16X16, NULL, PARTITION_NONE, 0,
+      av1_alloc_pmc(cm, SHARED_PART, 0, 0, BLOCK_16X16, NULL, PARTITION_NONE, 0,
                     cm->seq_params.subsampling_x, cm->seq_params.subsampling_y,
                     &cpi->td.shared_coeff_buf);
 }
diff --git a/av1/encoder/ethread.c b/av1/encoder/ethread.c
index df56a28..a5c4482 100644
--- a/av1/encoder/ethread.c
+++ b/av1/encoder/ethread.c
@@ -675,7 +675,7 @@
     if (i > 0) {
       // Set up firstpass PICK_MODE_CONTEXT.
       thread_data->td->firstpass_ctx = av1_alloc_pmc(
-          cm, 0, 0, BLOCK_16X16, NULL, PARTITION_NONE, 0,
+          cm, SHARED_PART, 0, 0, BLOCK_16X16, NULL, PARTITION_NONE, 0,
           cm->seq_params.subsampling_x, cm->seq_params.subsampling_y,
           &thread_data->td->shared_coeff_buf);
 
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index e148501..7a20891 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -2442,8 +2442,8 @@
     const int ss_x = xd->plane[1].subsampling_x;
     const int ss_y = xd->plane[1].subsampling_y;
     set_chroma_ref_info(
-        mi_row, mi_col, ptree->index, bsize, &ptree->chroma_ref_info,
-        parent ? &parent->chroma_ref_info : NULL,
+        xd->tree_type, mi_row, mi_col, ptree->index, bsize,
+        &ptree->chroma_ref_info, parent ? &parent->chroma_ref_info : NULL,
         parent ? parent->bsize : BLOCK_INVALID,
         parent ? parent->partition : PARTITION_NONE, ss_x, ss_y);
 
@@ -2764,7 +2764,7 @@
   const int ss_y = cm->seq_params.subsampling_y;
 
   PARTITION_TREE *parent = ptree->parent;
-  set_chroma_ref_info(mi_row, mi_col, ptree->index, bsize,
+  set_chroma_ref_info(tree_type, mi_row, mi_col, ptree->index, bsize,
                       &ptree->chroma_ref_info,
                       parent ? &parent->chroma_ref_info : NULL,
                       parent ? parent->bsize : BLOCK_INVALID,
@@ -2984,8 +2984,8 @@
 
   if (pc_tree->none == NULL) {
     pc_tree->none =
-        av1_alloc_pmc(cm, mi_row, mi_col, bsize, pc_tree, PARTITION_NONE, 0,
-                      ss_x, ss_y, &td->shared_coeff_buf);
+        av1_alloc_pmc(cm, xd->tree_type, mi_row, mi_col, bsize, pc_tree,
+                      PARTITION_NONE, 0, ss_x, ss_y, &td->shared_coeff_buf);
   }
   PICK_MODE_CONTEXT *ctx_none = pc_tree->none;
 
@@ -3022,9 +3022,9 @@
   for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
     int x_idx = (i & 1) * hbs;
     int y_idx = (i >> 1) * hbs;
-    pc_tree->split[i] =
-        av1_alloc_pc_tree_node(mi_row + y_idx, mi_col + x_idx, split_subsize,
-                               pc_tree, PARTITION_SPLIT, i, i == 3, ss_x, ss_y);
+    pc_tree->split[i] = av1_alloc_pc_tree_node(
+        xd->tree_type, mi_row + y_idx, mi_col + x_idx, split_subsize, pc_tree,
+        PARTITION_SPLIT, i, i == 3, ss_x, ss_y);
   }
 #endif  // !CONFIG_EXT_RECUR_PARTITIONS
   switch (partition) {
@@ -3034,11 +3034,12 @@
       break;
     case PARTITION_HORZ:
 #if CONFIG_EXT_RECUR_PARTITIONS
-      pc_tree->horizontal[0] = av1_alloc_pc_tree_node(
-          mi_row, mi_col, subsize, pc_tree, PARTITION_HORZ, 0, 0, ss_x, ss_y);
+      pc_tree->horizontal[0] =
+          av1_alloc_pc_tree_node(xd->tree_type, mi_row, mi_col, subsize,
+                                 pc_tree, PARTITION_HORZ, 0, 0, ss_x, ss_y);
       pc_tree->horizontal[1] =
-          av1_alloc_pc_tree_node(mi_row + hbh, mi_col, subsize, pc_tree,
-                                 PARTITION_HORZ, 1, 1, ss_x, ss_y);
+          av1_alloc_pc_tree_node(xd->tree_type, mi_row + hbh, mi_col, subsize,
+                                 pc_tree, PARTITION_HORZ, 1, 1, ss_x, ss_y);
       av1_rd_use_partition(cpi, td, tile_data, mib, tp, mi_row, mi_col, subsize,
                            &last_part_rdc.rate, &last_part_rdc.dist, 1,
                            ptree ? ptree->sub_tree[0] : NULL,
@@ -3047,8 +3048,8 @@
       for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
         if (pc_tree->horizontal[i] == NULL) {
           pc_tree->horizontal[i] = av1_alloc_pmc(
-              cm, mi_row + hbs * i, mi_col, subsize, pc_tree, PARTITION_HORZ, i,
-              ss_x, ss_y, &td->shared_coeff_buf);
+              cm, xd->tree_type, mi_row + hbs * i, mi_col, subsize, pc_tree,
+              PARTITION_HORZ, i, ss_x, ss_y, &td->shared_coeff_buf);
         }
       }
       pick_sb_modes(cpi, tile_data, x, mi_row, mi_col, &last_part_rdc,
@@ -3084,11 +3085,12 @@
       break;
     case PARTITION_VERT:
 #if CONFIG_EXT_RECUR_PARTITIONS
-      pc_tree->vertical[0] = av1_alloc_pc_tree_node(
-          mi_row, mi_col, subsize, pc_tree, PARTITION_VERT, 0, 0, ss_x, ss_y);
+      pc_tree->vertical[0] =
+          av1_alloc_pc_tree_node(xd->tree_type, mi_row, mi_col, subsize,
+                                 pc_tree, PARTITION_VERT, 0, 0, ss_x, ss_y);
       pc_tree->vertical[1] =
-          av1_alloc_pc_tree_node(mi_row, mi_col + hbw, subsize, pc_tree,
-                                 PARTITION_VERT, 1, 1, ss_x, ss_y);
+          av1_alloc_pc_tree_node(xd->tree_type, mi_row, mi_col + hbw, subsize,
+                                 pc_tree, PARTITION_VERT, 1, 1, ss_x, ss_y);
       av1_rd_use_partition(cpi, td, tile_data, mib, tp, mi_row, mi_col, subsize,
                            &last_part_rdc.rate, &last_part_rdc.dist, 1,
                            ptree ? ptree->sub_tree[0] : NULL,
@@ -3097,8 +3099,8 @@
       for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
         if (pc_tree->vertical[i] == NULL) {
           pc_tree->vertical[i] = av1_alloc_pmc(
-              cm, mi_row, mi_col + hbs * i, subsize, pc_tree, PARTITION_VERT, i,
-              ss_x, ss_y, &td->shared_coeff_buf);
+              cm, xd->tree_type, mi_row, mi_col + hbs * i, subsize, pc_tree,
+              PARTITION_VERT, i, ss_x, ss_y, &td->shared_coeff_buf);
         }
       }
       pick_sb_modes(cpi, tile_data, x, mi_row, mi_col, &last_part_rdc,
@@ -4070,11 +4072,11 @@
       }
     }
     sub_tree[0] = av1_alloc_pc_tree_node(
-        mi_pos_rect[i][0][0], mi_pos_rect[i][0][1], blk_params.subsize, pc_tree,
-        partition_type, 0, 0, ss_x, ss_y);
+        xd->tree_type, mi_pos_rect[i][0][0], mi_pos_rect[i][0][1],
+        blk_params.subsize, pc_tree, partition_type, 0, 0, ss_x, ss_y);
     sub_tree[1] = av1_alloc_pc_tree_node(
-        mi_pos_rect[i][1][0], mi_pos_rect[i][1][1], blk_params.subsize, pc_tree,
-        partition_type, 1, 1, ss_x, ss_y);
+        xd->tree_type, mi_pos_rect[i][1][0], mi_pos_rect[i][1][1],
+        blk_params.subsize, pc_tree, partition_type, 1, 1, ss_x, ss_y);
 
     bool both_blocks_skippable = true;
     const bool track_ptree_luma =
@@ -4090,10 +4092,11 @@
     for (int j = 0; j < SUB_PARTITIONS_RECT; j++) {
       assert(cur_ctx[i][j] != NULL);
       if (cur_ctx[i][j][0] == NULL) {
-        cur_ctx[i][j][0] = av1_alloc_pmc(
-            cm, mi_pos_rect[i][j][0], mi_pos_rect[i][j][1], blk_params.subsize,
-            pc_tree, partition_type, j, part_search_state->ss_x,
-            part_search_state->ss_y, &td->shared_coeff_buf);
+        cur_ctx[i][j][0] =
+            av1_alloc_pmc(cm, xd->tree_type, mi_pos_rect[i][j][0],
+                          mi_pos_rect[i][j][1], blk_params.subsize, pc_tree,
+                          partition_type, j, part_search_state->ss_x,
+                          part_search_state->ss_y, &td->shared_coeff_buf);
       }
     }
     sum_rdc->rate = part_search_state->partition_cost[partition_type];
@@ -4369,9 +4372,9 @@
       // Set AB partition context.
       if (cur_part_ctxs[ab_part_type][i] == NULL)
         cur_part_ctxs[ab_part_type][i] = av1_alloc_pmc(
-            cm, ab_mi_pos[ab_part_type][i][0], ab_mi_pos[ab_part_type][i][1],
-            ab_subsize[ab_part_type][i], pc_tree, part_type, i,
-            part_search_state->ss_x, part_search_state->ss_y,
+            cm, xd->tree_type, ab_mi_pos[ab_part_type][i][0],
+            ab_mi_pos[ab_part_type][i][1], ab_subsize[ab_part_type][i], pc_tree,
+            part_type, i, part_search_state->ss_x, part_search_state->ss_y,
             &td->shared_coeff_buf);
       // Set mode as not ready.
       cur_part_ctxs[ab_part_type][i]->rd_mode_is_ready = 0;
@@ -4432,8 +4435,8 @@
   for (PART4_TYPES i = 0; i < SUB_PARTITIONS_PART4; ++i) {
     if (cur_part_ctx[i] == NULL)
       cur_part_ctx[i] =
-          av1_alloc_pmc(cm, mi_pos[i][0], mi_pos[i][1], subsize, pc_tree,
-                        partition_type, i, part_search_state->ss_x,
+          av1_alloc_pmc(cm, xd->tree_type, mi_pos[i][0], mi_pos[i][1], subsize,
+                        pc_tree, partition_type, i, part_search_state->ss_x,
                         part_search_state->ss_y, &td->shared_coeff_buf);
   }
 }
@@ -4650,9 +4653,9 @@
   // Set PARTITION_NONE context.
   if (pc_tree->none == NULL)
     pc_tree->none = av1_alloc_pmc(
-        cm, blk_params.mi_row, blk_params.mi_col, blk_params.bsize, pc_tree,
-        PARTITION_NONE, 0, part_search_state->ss_x, part_search_state->ss_y,
-        &td->shared_coeff_buf);
+        cm, x->e_mbd.tree_type, blk_params.mi_row, blk_params.mi_col,
+        blk_params.bsize, pc_tree, PARTITION_NONE, 0, part_search_state->ss_x,
+        part_search_state->ss_y, &td->shared_coeff_buf);
 
   // Set PARTITION_NONE type cost.
   if (part_search_state->partition_none_allowed) {
@@ -4995,8 +4998,9 @@
 
     if (pc_tree->split[idx] == NULL) {
       pc_tree->split[idx] = av1_alloc_pc_tree_node(
-          mi_row + y_idx, mi_col + x_idx, subsize, pc_tree, PARTITION_SPLIT,
-          idx, idx == 3, part_search_state->ss_x, part_search_state->ss_y);
+          x->e_ebd.tree_type, mi_row + y_idx, mi_col + x_idx, subsize, pc_tree,
+          PARTITION_SPLIT, idx, idx == 3, part_search_state->ss_x,
+          part_search_state->ss_y);
     }
     int64_t *p_split_rd = &part_search_state->split_rd[idx];
     RD_STATS best_remain_rdcost;
@@ -5831,8 +5835,8 @@
     }
     const int this_mi_row = mi_row + eighth_step * cum_step_multipliers[idx];
     pc_tree->horizontal4a[idx] = av1_alloc_pc_tree_node(
-        this_mi_row, mi_col, subblock_sizes[idx], pc_tree, PARTITION_HORZ_4A,
-        idx, idx == 3, ss_x, ss_y);
+        xd->tree_type, this_mi_row, mi_col, subblock_sizes[idx], pc_tree,
+        PARTITION_HORZ_4A, idx, idx == 3, ss_x, ss_y);
   }
 
   bool skippable = true;
@@ -5941,8 +5945,8 @@
     }
     const int this_mi_row = mi_row + eighth_step * cum_step_multipliers[idx];
     pc_tree->horizontal4b[idx] = av1_alloc_pc_tree_node(
-        this_mi_row, mi_col, subblock_sizes[idx], pc_tree, PARTITION_HORZ_4B,
-        idx, idx == 3, ss_x, ss_y);
+        xd->tree_type, this_mi_row, mi_col, subblock_sizes[idx], pc_tree,
+        PARTITION_HORZ_4B, idx, idx == 3, ss_x, ss_y);
   }
 
   bool skippable = true;
@@ -6051,8 +6055,8 @@
     }
     const int this_mi_col = mi_col + eighth_step * cum_step_multipliers[idx];
     pc_tree->vertical4a[idx] = av1_alloc_pc_tree_node(
-        mi_row, this_mi_col, subblock_sizes[idx], pc_tree, PARTITION_VERT_4A,
-        idx, idx == 3, ss_x, ss_y);
+        xd->tree_type, mi_row, this_mi_col, subblock_sizes[idx], pc_tree,
+        PARTITION_VERT_4A, idx, idx == 3, ss_x, ss_y);
   }
 
   bool skippable = true;
@@ -6161,8 +6165,8 @@
     }
     const int this_mi_col = mi_col + eighth_step * cum_step_multipliers[idx];
     pc_tree->vertical4b[idx] = av1_alloc_pc_tree_node(
-        mi_row, this_mi_col, subblock_sizes[idx], pc_tree, PARTITION_VERT_4B,
-        idx, idx == 3, ss_x, ss_y);
+        xd->tree_type, mi_row, this_mi_col, subblock_sizes[idx], pc_tree,
+        PARTITION_VERT_4B, idx, idx == 3, ss_x, ss_y);
   }
 
   bool skippable = true;
@@ -6277,8 +6281,9 @@
     }
 
     pc_tree->horizontal3[idx] = av1_alloc_pc_tree_node(
-        mi_row + offset_mr[idx], mi_col + offset_mc[idx], subblock_sizes[idx],
-        pc_tree, PARTITION_HORZ_3, idx, idx == 3, ss_x, ss_y);
+        xd->tree_type, mi_row + offset_mr[idx], mi_col + offset_mc[idx],
+        subblock_sizes[idx], pc_tree, PARTITION_HORZ_3, idx, idx == 3, ss_x,
+        ss_y);
   }
 #else   // CONFIG_H_PARTITION
   const BLOCK_SIZE sml_subsize = get_partition_subsize(bsize, PARTITION_HORZ_3);
@@ -6294,14 +6299,14 @@
     }
   }
   pc_tree->horizontal3[0] =
-      av1_alloc_pc_tree_node(mi_row, mi_col, subblock_sizes[0], pc_tree,
-                             PARTITION_HORZ_3, 0, 0, ss_x, ss_y);
-  pc_tree->horizontal3[1] =
-      av1_alloc_pc_tree_node(mi_row + quarter_step, mi_col, subblock_sizes[1],
-                             pc_tree, PARTITION_HORZ_3, 1, 0, ss_x, ss_y);
+      av1_alloc_pc_tree_node(xd->tree_type, mi_row, mi_col, subblock_sizes[0],
+                             pc_tree, PARTITION_HORZ_3, 0, 0, ss_x, ss_y);
+  pc_tree->horizontal3[1] = av1_alloc_pc_tree_node(
+      xd->tree_type, mi_row + quarter_step, mi_col, subblock_sizes[1], pc_tree,
+      PARTITION_HORZ_3, 1, 0, ss_x, ss_y);
   pc_tree->horizontal3[2] = av1_alloc_pc_tree_node(
-      mi_row + quarter_step * 3, mi_col, subblock_sizes[2], pc_tree,
-      PARTITION_HORZ_3, 2, 1, ss_x, ss_y);
+      xd->tree_type, mi_row + quarter_step * 3, mi_col, subblock_sizes[2],
+      pc_tree, PARTITION_HORZ_3, 2, 1, ss_x, ss_y);
 #endif  // CONFIG_H_PARTITION
 
   bool skippable = true;
@@ -6427,8 +6432,9 @@
     }
 
     pc_tree->vertical3[idx] = av1_alloc_pc_tree_node(
-        mi_row + offset_mr[idx], mi_col + offset_mc[idx], subblock_sizes[idx],
-        pc_tree, PARTITION_VERT_3, idx, idx == 3, ss_x, ss_y);
+        xd->tree_type, mi_row + offset_mr[idx], mi_col + offset_mc[idx],
+        subblock_sizes[idx], pc_tree, PARTITION_VERT_3, idx, idx == 3, ss_x,
+        ss_y);
   }
 #else
   const BLOCK_SIZE sml_subsize = get_partition_subsize(bsize, PARTITION_VERT_3);
@@ -6444,14 +6450,14 @@
     }
   }
   pc_tree->vertical3[0] =
-      av1_alloc_pc_tree_node(mi_row, mi_col, subblock_sizes[0], pc_tree,
-                             PARTITION_VERT_3, 0, 0, ss_x, ss_y);
-  pc_tree->vertical3[1] =
-      av1_alloc_pc_tree_node(mi_row, mi_col + quarter_step, subblock_sizes[1],
-                             pc_tree, PARTITION_VERT_3, 1, 0, ss_x, ss_y);
+      av1_alloc_pc_tree_node(xd->tree_type, mi_row, mi_col, subblock_sizes[0],
+                             pc_tree, PARTITION_VERT_3, 0, 0, ss_x, ss_y);
+  pc_tree->vertical3[1] = av1_alloc_pc_tree_node(
+      xd->tree_type, mi_row, mi_col + quarter_step, subblock_sizes[1], pc_tree,
+      PARTITION_VERT_3, 1, 0, ss_x, ss_y);
   pc_tree->vertical3[2] = av1_alloc_pc_tree_node(
-      mi_row, mi_col + quarter_step * 3, subblock_sizes[2], pc_tree,
-      PARTITION_VERT_3, 2, 1, ss_x, ss_y);
+      xd->tree_type, mi_row, mi_col + quarter_step * 3, subblock_sizes[2],
+      pc_tree, PARTITION_VERT_3, 2, 1, ss_x, ss_y);
 #endif  // CONFIG_H_PARTITION
 
   bool skippable = true;
@@ -6732,7 +6738,8 @@
     if (counterpart_block->rd_cost.rate != INT_MAX) {
       av1_copy_pc_tree_recursive(cm, pc_tree, counterpart_block,
                                  part_search_state.ss_x, part_search_state.ss_y,
-                                 &td->shared_coeff_buf, num_planes);
+                                 &td->shared_coeff_buf, xd->tree_type,
+                                 num_planes);
       *rd_cost = pc_tree->rd_cost;
 #if CONFIG_C043_MVP_IMPROVEMENTS
       x->e_mbd.ref_mv_bank = counterpart_block->ref_mv_bank;