Cross chroma transform for intra and misc optimizations
diff --git a/aom_dsp/txfm_common.h b/aom_dsp/txfm_common.h
index 47a276e..42b12e6 100644
--- a/aom_dsp/txfm_common.h
+++ b/aom_dsp/txfm_common.h
@@ -33,6 +33,9 @@
   // intra prediction mode used for the current tx block
   PREDICTION_MODE intra_mode;
 #endif
+#if CONFIG_CROSS_CHROMA_TX
+  CctxType cctx_type;
+#endif  // CONFIG_CROSS_CHROMA_TX
   TX_SIZE tx_size;
   int lossless;
   int bd;
diff --git a/av1/common/alloccommon.c b/av1/common/alloccommon.c
index c851289..ed356a8 100644
--- a/av1/common/alloccommon.c
+++ b/av1/common/alloccommon.c
@@ -239,6 +239,11 @@
     mi_params->tx_type_map =
         aom_calloc(mi_grid_size, sizeof(*mi_params->tx_type_map));
     if (!mi_params->tx_type_map) return 1;
+#if CONFIG_CROSS_CHROMA_TX
+    mi_params->cctx_type_map =
+        aom_calloc(mi_grid_size, sizeof(*mi_params->cctx_type_map));
+    if (!mi_params->cctx_type_map) return 1;
+#endif  // CONFIG_CROSS_CHROMA_TX
   }
 
   return 0;
diff --git a/av1/common/av1_common_int.h b/av1/common/av1_common_int.h
index b2cfe6f..06123ba 100644
--- a/av1/common/av1_common_int.h
+++ b/av1/common/av1_common_int.h
@@ -764,6 +764,15 @@
    * primary tx_type
    */
   TX_TYPE *tx_type_map;
+#if CONFIG_CROSS_CHROMA_TX
+  /*!
+   * An array of cctx types for each 4x4 block in the frame.
+   * Number of allocated elements is same as 'mi_grid_size', and stride is
+   * same as 'mi_grid_size'. So, indexing into 'tx_type_map' is same as that of
+   * 'mi_grid_base'.
+   */
+  CctxType *cctx_type_map;
+#endif  // CONFIG_CROSS_CHROMA_TX
 
   /**
    * \name Function pointers to allow separate logic for encoder and decoder.
@@ -2219,8 +2228,12 @@
   // 'xd->mi' should point to an offset in 'mi_grid_base';
   xd->mi = mi_params->mi_grid_base + mi_grid_idx;
   // 'xd->tx_type_map' should point to an offset in 'mi_params->tx_type_map'.
-  if (xd->tree_type != CHROMA_PART)
+  if (xd->tree_type != CHROMA_PART) {
     xd->tx_type_map = mi_params->tx_type_map + mi_grid_idx;
+#if CONFIG_CROSS_CHROMA_TX
+    xd->cctx_type_map = mi_params->cctx_type_map + mi_grid_idx;
+#endif  // CONFIG_CROSS_CHROMA_TX
+  }
   xd->tx_type_map_stride = mi_params->mi_stride;
 }
 
@@ -2238,6 +2251,9 @@
   xd->mi[mi_params->mi_stride * blk_row + blk_col] =
       mi_params->mi_grid_base[mi_grid_idx];
   xd->tx_type_map = mi_params->tx_type_map + mi_grid_idx;
+#if CONFIG_CROSS_CHROMA_TX
+  xd->cctx_type_map = mi_params->cctx_type_map + mi_grid_idx;
+#endif  // CONFIG_CROSS_CHROMA_TX
   xd->tx_type_map_stride = mi_params->mi_stride;
 }
 // Return the number of sub-blocks whose width and height are
diff --git a/av1/common/av1_txfm.c b/av1/common/av1_txfm.c
index a38b11d..267c59e 100644
--- a/av1/common/av1_txfm.c
+++ b/av1/common/av1_txfm.c
@@ -197,8 +197,61 @@
 #endif  // CONFIG_DST_32X32
 
 #if CONFIG_CROSS_CHROMA_TX
-// Haar transform [1, 1; 1, -1] * 1/sqrt(2) * (1<<CCTX_PREC_BITS)
-const int32_t cctx_mtx[4] = { 181, 181, 181, -181 };
+// Given a rotation angle t, the CCTX transform matrix is defined as
+// [cos(t), sin(t); -sin(t), cos(t)] * 1<<CCTX_PREC_BITS). The array below only
+// stores two values: cos(t) and sin(t) for each rotation angle.
+const int32_t cctx_mtx[CCTX_TYPES - 1][2] = {
+#if CCTX_ANGLE_CONFIG == 0
+#if CCTX_POS_ANGLES
+  { 181, 181 },  // t = 45 degrees
+  { 222, 128 },  // t = 30 degrees
+  { 128, 222 },  // t = 60 degrees
+#endif           // CCTX_POS_ANGLES
+#if CCTX_NEG_ANGLES
+  { 181, -181 },  // t = -45 degrees
+  { 222, -128 },  // t = -30 degrees
+  { 128, -222 },  // t = -60 degrees
+#endif            // CCTX_NEG_ANGLES
+#elif CCTX_ANGLE_CONFIG == 1
+#if CCTX_POS_ANGLES
+  { 181, 181 },   // t = 45 degrees
+  { 236, 98 },    // t = 22.5 degrees
+  { 98, 236 },    // t = 67.5 degrees
+#endif  // CCTX_POS_ANGLES
+#if CCTX_NEG_ANGLES
+  { 181, -181 },  // t = -45 degrees
+  { 236, -98 },   // t = -22.5 degrees
+  { 98, -236 },   // t = -67.5 degrees
+#endif  // CCTX_NEG_ANGLES
+#elif CCTX_ANGLE_CONFIG == 2
+#if CCTX_POS_ANGLES
+  { 181, 181 },   // t = 45 degrees
+  { 232, 108 },   // t = 25 degrees
+  { 108, 232 },   // t = 65 degrees
+#endif  // CCTX_POS_ANGLES
+#if CCTX_NEG_ANGLES
+  { 181, -181 },  // t = -45 degrees
+  { 232, -108 },  // t = -25 degrees
+  { 108, -232 },  // t = -65 degrees
+#endif  // CCTX_NEG_ANGLES
+#elif CCTX_ANGLE_CONFIG == 3
+#if CCTX_POS_ANGLES
+  { 222, 128 },   // t = 30 degrees
+  { 128, 222 },   // t = 60 degrees
+#endif  // CCTX_POS_ANGLES
+#if CCTX_NEG_ANGLES
+  { 222, -128 },  // t = -30 degrees
+  { 128, -222 },  // t = -60 degrees
+#endif  // CCTX_NEG_ANGLES
+#elif CCTX_ANGLE_CONFIG == 4
+#if CCTX_POS_ANGLES
+  { 181, 181 },   // t = 45 degrees
+#endif  // CCTX_POS_ANGLES
+#if CCTX_NEG_ANGLES
+  { 181, -181 },  // t = -45 degrees
+#endif  // CCTX_NEG_ANGLES
+#endif
+};
 #endif  // CONFIG_CROSS_CHROMA_TX
 
 // av1_sinpi_arr_data[i][j] = (int)round((sqrt(2) * sin(j*Pi/9) * 2 / 3) * (1
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h
index f68cf5e..7cbb9e9 100644
--- a/av1/common/av1_txfm.h
+++ b/av1/common/av1_txfm.h
@@ -50,9 +50,11 @@
 #endif  // CONFIG_DDT_INTER
 
 #if CONFIG_CROSS_CHROMA_TX
+#define CCTX_INTER 1
+#define CCTX_INTRA 1
 #define CCTX_DC_ONLY 0
 #define CCTX_PREC_BITS 8
-extern const int32_t cctx_mtx[4];
+extern const int32_t cctx_mtx[CCTX_TYPES - 1][2];
 #endif  // CONFIG_CROSS_CHROMA_TX
 
 #define MAX_TXFM_STAGE_NUM 12
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index b5bf2a9..aee81be 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -897,6 +897,12 @@
    * 'MACROBLOCK' structs.
    */
   TX_TYPE *tx_type_map;
+#if CONFIG_CROSS_CHROMA_TX
+  /*!
+   * Array of CCTX types.
+   */
+  CctxType *cctx_type_map;
+#endif  // CONFIG_CROSS_CHROMA_TX
   /*!
    * Stride for 'tx_type_map'. Note that this may / may not be same as
    * 'mi_stride', depending on which actual array 'tx_type_map' points to.
@@ -1555,6 +1561,44 @@
   }
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+#if CCTX_C2_DROPPED
+static INLINE int keep_chroma_c2(CctxType cctx_type) {
+  return
+#if CCTX_NEG_ANGLES
+      cctx_type == CCTX_M30 || cctx_type == CCTX_M60 ||
+#endif
+#if CCTX_POS_ANGLES
+      cctx_type == CCTX_30 || cctx_type == CCTX_60 ||
+#endif
+      cctx_type == CCTX_NONE;
+}
+#endif
+
+static INLINE void update_cctx_array(MACROBLOCKD *const xd, int blk_row,
+                                     int blk_col, TX_SIZE tx_size,
+                                     CctxType cctx_type) {
+  const int stride = xd->tx_type_map_stride;
+  xd->cctx_type_map[blk_row * stride + blk_col] = cctx_type;
+
+  const int txw = tx_size_wide_unit[tx_size];
+  const int txh = tx_size_high_unit[tx_size];
+  // The 16x16 unit is due to the constraint from tx_64x64 which sets the
+  // maximum tx size for chroma as 32x32. Coupled with 4x1 transform block
+  // size, the constraint takes effect in 32x16 / 16x32 size too. To solve
+  // the intricacy, cover all the 16x16 units inside a 64 level transform.
+  if (txw == tx_size_wide_unit[TX_64X64] ||
+      txh == tx_size_high_unit[TX_64X64]) {
+    const int tx_unit = tx_size_wide_unit[TX_16X16];
+    for (int idy = 0; idy < txh; idy += tx_unit) {
+      for (int idx = 0; idx < txw; idx += tx_unit) {
+        xd->cctx_type_map[(blk_row + idy) * stride + blk_col + idx] = cctx_type;
+      }
+    }
+  }
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 #if CONFIG_IST
 static INLINE int tx_size_is_depth0(TX_SIZE tx_size, BLOCK_SIZE bsize) {
   TX_SIZE ctx_size = max_txsize_rect_lookup[bsize];
@@ -1733,6 +1777,13 @@
   return tx_type;
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+static INLINE CctxType av1_get_cctx_type(const MACROBLOCKD *xd, int blk_row,
+                                         int blk_col) {
+  return xd->cctx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 void av1_setup_block_planes(MACROBLOCKD *xd, int ss_x, int ss_y,
                             const int num_planes);
 
diff --git a/av1/common/entropy.c b/av1/common/entropy.c
index 8bc639f..86b6465 100644
--- a/av1/common/entropy.c
+++ b/av1/common/entropy.c
@@ -295,4 +295,7 @@
 #if CONFIG_IST
   RESET_CDF_COUNTER_STRIDE(fc->stx_cdf, STX_TYPES, CDF_SIZE(STX_TYPES));
 #endif
+#if CONFIG_CROSS_CHROMA_TX
+  RESET_CDF_COUNTER(fc->cctx_type_cdf, CCTX_TYPES);
+#endif  // CONFIG_CROSS_CHROMA_TX
 }
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index f99767b..e48fea4 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -683,6 +683,65 @@
                                                     };
 #endif  // CONFIG_DDT_INTER
 
+#if CONFIG_CROSS_CHROMA_TX
+static const aom_cdf_prob default_cctx_type_cdf[EXT_TX_SIZES]
+                                               [CDF_SIZE(CCTX_TYPES)] = {
+#if CCTX_ANGLE_CONFIG == 4
+#if CCTX_POS_ANGLES && CCTX_NEG_ANGLES
+                                                 { AOM_CDF3(10923, 21845) },
+                                                 { AOM_CDF3(10923, 21845) },
+                                                 { AOM_CDF3(10923, 21845) },
+                                                 { AOM_CDF3(10923, 21845) },
+#else
+                                                 { AOM_CDF2(16384) },
+                                                 { AOM_CDF2(16384) },
+                                                 { AOM_CDF2(16384) },
+                                                 { AOM_CDF2(16384) },
+#endif  // CCTX_POS_ANGLES && CCTX_NEG_ANGLES
+#elif CCTX_ANGLE_CONFIG == 3
+#if CCTX_POS_ANGLES && CCTX_NEG_ANGLES
+                                                 { AOM_CDF5(6554, 13107, 19661,
+                                                            26214) },
+                                                 { AOM_CDF5(6554, 13107, 19661,
+                                                            26214) },
+                                                 { AOM_CDF5(6554, 13107, 19661,
+                                                            26214) },
+                                                 { AOM_CDF5(6554, 13107, 19661,
+                                                            26214) },
+#else
+                                                 { AOM_CDF3(10923, 21845) },
+                                                 { AOM_CDF3(10923, 21845) },
+                                                 { AOM_CDF3(10923, 21845) },
+                                                 { AOM_CDF3(10923, 21845) },
+#endif  // CCTX_POS_ANGLES && CCTX_NEG_ANGLES
+#else
+#if CCTX_POS_ANGLES && CCTX_NEG_ANGLES
+                                                 { AOM_CDF7(4681, 9362, 14043,
+                                                            18725, 23406,
+                                                            28087) },
+                                                 { AOM_CDF7(4681, 9362, 14043,
+                                                            18725, 23406,
+                                                            28087) },
+                                                 { AOM_CDF7(4681, 9362, 14043,
+                                                            18725, 23406,
+                                                            28087) },
+                                                 { AOM_CDF7(4681, 9362, 14043,
+                                                            18725, 23406,
+                                                            28087) },
+#else
+                                                 { AOM_CDF4(8192, 16384,
+                                                            24576) },
+                                                 { AOM_CDF4(8192, 16384,
+                                                            24576) },
+                                                 { AOM_CDF4(8192, 16384,
+                                                            24576) },
+                                                 { AOM_CDF4(8192, 16384,
+                                                            24576) },
+#endif  // CCTX_POS_ANGLES && CCTX_NEG_ANGLES
+#endif
+                                               };
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 static const aom_cdf_prob default_cfl_sign_cdf[CDF_SIZE(CFL_JOINT_SIGNS)] = {
   AOM_CDF8(1418, 2123, 13340, 18405, 26972, 28343, 32294)
 };
@@ -1943,6 +2002,9 @@
 #if CONFIG_IST
   av1_copy(fc->stx_cdf, default_stx_cdf);
 #endif
+#if CONFIG_CROSS_CHROMA_TX
+  av1_copy(fc->cctx_type_cdf, default_cctx_type_cdf);
+#endif  // CONFIG_CROSS_CHROMA_TX
 }
 
 void av1_set_default_ref_deltas(int8_t *ref_deltas) {
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index 2955367..c2267a0 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -272,6 +272,9 @@
 #if CONFIG_IST
   aom_cdf_prob stx_cdf[TX_SIZES][CDF_SIZE(STX_TYPES)];
 #endif
+#if CONFIG_CROSS_CHROMA_TX
+  aom_cdf_prob cctx_type_cdf[EXT_TX_SIZES][CDF_SIZE(CCTX_TYPES)];
+#endif  // CONFIG_CROSS_CHROMA_TX
   int initialized;
 } FRAME_CONTEXT;
 
diff --git a/av1/common/enums.h b/av1/common/enums.h
index f2b097c..9f33c0c 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -403,6 +403,47 @@
   DCT_ADST_TX_MASK = 0x000F,             // Either DCT or ADST in each direction
 } UENUM1BYTE(TX_TYPE);
 
+#if CONFIG_CROSS_CHROMA_TX
+#define CCTX_NEG_ANGLES 1
+#define CCTX_POS_ANGLES 1
+// Always signal C1 coefficients for some cctx (i.e., both C1 and C2 nonzero
+// or C1 nonzero and C2 zero). This requires CCTX_NEG_ANGLES to be on.
+#define CCTX_C1_NONZERO 1
+// Drop C2 channel for some cctx_types. This macro requires CCTX_C1_NONZERO to
+// be on.
+#define CCTX_C2_DROPPED 0
+// Configuration for the set of rotation angles
+// 0: { 45, 30, 60 }
+// 1: { 45, 22.5, 67.5 }
+// 2: { 45, 25, 65 }
+// 3: { 30, 60 }
+// 4: { 45 }
+#define CCTX_ANGLE_CONFIG 0
+enum {
+  CCTX_NONE,  // No cross chroma transform
+#if CCTX_POS_ANGLES
+#if CCTX_ANGLE_CONFIG != 3
+  CCTX_45,  // 45 degrees rotation (Haar transform)
+#endif
+#if CCTX_ANGLE_CONFIG != 4
+  CCTX_30,  // 30 degrees rotation
+  CCTX_60,  // 60 degrees rotation
+#endif
+#endif  // CCTX_POS_ANGLES
+#if CCTX_NEG_ANGLES
+#if CCTX_ANGLE_CONFIG != 3
+  CCTX_M45,  // -45 degrees rotation
+#endif
+#if CCTX_ANGLE_CONFIG != 4
+  CCTX_M30,  // -30 degrees rotation
+  CCTX_M60,  // -60 degrees rotation
+#endif
+#endif  // CCTX_NEG_ANGLES
+  CCTX_TYPES,
+  CCTX_START = CCTX_NONE + 1,
+} UENUM1BYTE(CctxType);
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 enum {
   REG_REG,
   REG_SMOOTH,
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 7e56a68..42378bc 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -322,7 +322,8 @@
 
 #if CONFIG_CROSS_CHROMA_TX
 void av1_inv_cross_chroma_tx_block(tran_low_t *dqcoeff_u, tran_low_t *dqcoeff_v,
-                                   TX_SIZE tx_size) {
+                                   TX_SIZE tx_size, CctxType cctx_type) {
+  if (cctx_type == CCTX_NONE) return;
 #if CCTX_DC_ONLY
   const int ncoeffs = 1;
 #else
@@ -334,9 +335,12 @@
   int32_t *src_v = (int32_t *)dqcoeff_v;
   int32_t tmp[2] = { 0, 0 };
 
+  const int angle_idx = cctx_type - CCTX_START;
   for (int i = 0; i < ncoeffs; i++) {
-    tmp[0] = cctx_mtx[0] * src_u[i] + cctx_mtx[2] * src_v[i];
-    tmp[1] = cctx_mtx[1] * src_u[i] + cctx_mtx[3] * src_v[i];
+    tmp[0] =
+        cctx_mtx[angle_idx][0] * src_u[i] - cctx_mtx[angle_idx][1] * src_v[i];
+    tmp[1] =
+        cctx_mtx[angle_idx][1] * src_u[i] + cctx_mtx[angle_idx][0] * src_v[i];
     src_u[i] = ROUND_POWER_OF_TWO_SIGNED(tmp[0], CCTX_PREC_BITS);
     src_v[i] = ROUND_POWER_OF_TWO_SIGNED(tmp[1], CCTX_PREC_BITS);
   }
diff --git a/av1/common/idct.h b/av1/common/idct.h
index 62b5c62..4057fc3 100644
--- a/av1/common/idct.h
+++ b/av1/common/idct.h
@@ -35,7 +35,7 @@
 
 #if CONFIG_CROSS_CHROMA_TX
 void av1_inv_cross_chroma_tx_block(tran_low_t *dqcoeff_u, tran_low_t *dqcoeff_v,
-                                   TX_SIZE tx_size);
+                                   TX_SIZE tx_size, CctxType cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
 
 void av1_inverse_transform_block(const MACROBLOCKD *xd,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 1c4664d..ba3490a 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -248,7 +248,20 @@
   av1_predict_intra_block_facade(cm, xd, plane, col, row, tx_size);
   if (!mbmi->skip_txfm[xd->tree_type == CHROMA_PART]) {
     eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+    eob_info *eob_data_u =
+        dcb->eob_data[AOM_PLANE_U] + dcb->txb_offset[AOM_PLANE_U];
+#if CCTX_C1_NONZERO
+    if (eob_data->eob || (av1_get_cctx_type(xd, row, col) > CCTX_NONE &&
+                          plane == AOM_PLANE_V && eob_data_u->eob)) {
+#else
+    eob_info *eob_data_v =
+        dcb->eob_data[AOM_PLANE_V] + dcb->txb_offset[AOM_PLANE_V];
+    if (eob_data->eob || (plane && (eob_data_u->eob || eob_data_v->eob))) {
+#endif
+#else
     if (eob_data->eob) {
+#endif  // CONFIG_CROSS_CHROMA_TX
       const bool reduced_tx_set_used = cm->features.reduced_tx_set_used;
       // tx_type was read out in av1_read_coeffs_txb.
       const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, row, col, tx_size,
@@ -273,13 +286,13 @@
   (void)cm;
   (void)r;
   (void)plane;
-  (void)blk_row;
-  (void)blk_col;
   tran_low_t *dqcoeff_u =
       dcb->dqcoeff_block[AOM_PLANE_U] + dcb->cb_offset[AOM_PLANE_U];
   tran_low_t *dqcoeff_v =
       dcb->dqcoeff_block[AOM_PLANE_V] + dcb->cb_offset[AOM_PLANE_V];
-  av1_inv_cross_chroma_tx_block(dqcoeff_u, dqcoeff_v, tx_size);
+  MACROBLOCKD *const xd = &dcb->xd;
+  const CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+  av1_inv_cross_chroma_tx_block(dqcoeff_u, dqcoeff_v, tx_size, cctx_type);
 }
 #endif  // CONFIG_CROSS_CHROMA_TX
 
@@ -325,7 +338,7 @@
     AV1_COMMON *cm, ThreadData *const td, aom_reader *r,
     MB_MODE_INFO *const mbmi, int plane, BLOCK_SIZE plane_bsize, int blk_row,
     int blk_col, int block, TX_SIZE tx_size, int *eob_total) {
-#if CONFIG_CROSS_CHROMA_TX
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTER
   if (plane == AOM_PLANE_U) return;
 #endif  // CONFIG_CROSS_CHROMA_TX
   DecoderCodingBlock *const dcb = &td->dcb;
@@ -345,48 +358,38 @@
   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
 
   if (tx_size == plane_tx_size || plane) {
-#if CONFIG_CROSS_CHROMA_TX
-    switch (plane) {
-      case AOM_PLANE_Y:
-        td->read_coeffs_tx_inter_block_visit(cm, dcb, r, plane, blk_row,
-                                             blk_col, tx_size);
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTER
+    if (plane == AOM_PLANE_V) {
+      td->read_coeffs_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_U, blk_row,
+                                           blk_col, tx_size);
+      td->read_coeffs_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_V, blk_row,
+                                           blk_col, tx_size);
+      td->inverse_cctx_block_visit(cm, dcb, r, -1, blk_row, blk_col, tx_size);
+      td->inverse_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_U, blk_row,
+                                       blk_col, tx_size);
+      td->inverse_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_V, blk_row,
+                                       blk_col, tx_size);
+      eob_info *eob_data_u =
+          dcb->eob_data[AOM_PLANE_U] + dcb->txb_offset[AOM_PLANE_U];
+      eob_info *eob_data_v =
+          dcb->eob_data[AOM_PLANE_V] + dcb->txb_offset[AOM_PLANE_V];
+      *eob_total += eob_data_u->eob + eob_data_v->eob;
+      set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_U);
+      set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_V);
+    } else {
+      assert(plane == AOM_PLANE_Y);
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTER
+      td->read_coeffs_tx_inter_block_visit(cm, dcb, r, plane, blk_row, blk_col,
+                                           tx_size);
 
-        td->inverse_tx_inter_block_visit(cm, dcb, r, plane, blk_row, blk_col,
-                                         tx_size);
-        eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
-        *eob_total += eob_data->eob;
-        set_cb_buffer_offsets(dcb, tx_size, plane);
-        break;
-      case AOM_PLANE_V:
-        td->read_coeffs_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_U, blk_row,
-                                             blk_col, tx_size);
-        td->read_coeffs_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_V, blk_row,
-                                             blk_col, tx_size);
-        td->inverse_cctx_block_visit(cm, dcb, r, -1, blk_row, blk_col, tx_size);
-        td->inverse_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_U, blk_row,
-                                         blk_col, tx_size);
-        td->inverse_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_V, blk_row,
-                                         blk_col, tx_size);
-        eob_info *eob_data_u =
-            dcb->eob_data[AOM_PLANE_U] + dcb->txb_offset[AOM_PLANE_U];
-        eob_info *eob_data_v =
-            dcb->eob_data[AOM_PLANE_V] + dcb->txb_offset[AOM_PLANE_V];
-        *eob_total += eob_data_u->eob + eob_data_v->eob;
-        set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_U);
-        set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_V);
-        break;
-      case AOM_PLANE_U: assert(0); break;
+      td->inverse_tx_inter_block_visit(cm, dcb, r, plane, blk_row, blk_col,
+                                       tx_size);
+      eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
+      *eob_total += eob_data->eob;
+      set_cb_buffer_offsets(dcb, tx_size, plane);
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTER
     }
-#else
-    td->read_coeffs_tx_inter_block_visit(cm, dcb, r, plane, blk_row, blk_col,
-                                         tx_size);
-
-    td->inverse_tx_inter_block_visit(cm, dcb, r, plane, blk_row, blk_col,
-                                     tx_size);
-    eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
-    *eob_total += eob_data->eob;
-    set_cb_buffer_offsets(dcb, tx_size, plane);
-#endif  // CONFIG_CROSS_CHROMA_TX
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTER
   } else {
 #if CONFIG_NEW_TX_PARTITION
     TX_SIZE sub_txs[MAX_TX_PARTITIONS] = { 0 };
@@ -1258,6 +1261,9 @@
     for (col = 0; col < max_blocks_wide; col += mu_blocks_wide) {
       for (int plane = plane_start; plane < plane_end; ++plane) {
         if (plane && !xd->is_chroma_ref) break;
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+        if (!is_inter && plane == AOM_PLANE_U) continue;
+#endif  // CONFIG_CROSS_CHROMA_TX
         const struct macroblockd_plane *const pd = &xd->plane[plane];
         const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
         const int ss_x = pd->subsampling_x;
@@ -1281,30 +1287,31 @@
           for (int blk_col = col >> ss_x; blk_col < unit_width;
                blk_col += stepc) {
             if (!is_inter) {
-              td->read_coeffs_tx_intra_block_visit(cm, dcb, r, plane, blk_row,
-                                                   blk_col, tx_size);
-#if CONFIG_CROSS_CHROMA_TX
-              switch (plane) {
-                case AOM_PLANE_Y:
-                  td->predict_and_recon_intra_block_visit(
-                      cm, dcb, r, AOM_PLANE_Y, blk_row, blk_col, tx_size);
-                  set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_Y);
-                  break;
-                case AOM_PLANE_U: break;
-                case AOM_PLANE_V:
-                  td->predict_and_recon_intra_block_visit(
-                      cm, dcb, r, AOM_PLANE_U, blk_row, blk_col, tx_size);
-                  set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_U);
-                  td->predict_and_recon_intra_block_visit(
-                      cm, dcb, r, AOM_PLANE_V, blk_row, blk_col, tx_size);
-                  set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_V);
-                  break;
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+              if (plane == AOM_PLANE_V) {
+                td->read_coeffs_tx_intra_block_visit(cm, dcb, r, AOM_PLANE_U,
+                                                     blk_row, blk_col, tx_size);
+                td->read_coeffs_tx_intra_block_visit(cm, dcb, r, AOM_PLANE_V,
+                                                     blk_row, blk_col, tx_size);
+                td->inverse_cctx_block_visit(cm, dcb, r, -1, blk_row, blk_col,
+                                             tx_size);
+                td->predict_and_recon_intra_block_visit(
+                    cm, dcb, r, AOM_PLANE_U, blk_row, blk_col, tx_size);
+                td->predict_and_recon_intra_block_visit(
+                    cm, dcb, r, AOM_PLANE_V, blk_row, blk_col, tx_size);
+                set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_U);
+                set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_V);
+              } else {
+                assert(plane == AOM_PLANE_Y);
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+                td->read_coeffs_tx_intra_block_visit(cm, dcb, r, plane, blk_row,
+                                                     blk_col, tx_size);
+                td->predict_and_recon_intra_block_visit(
+                    cm, dcb, r, plane, blk_row, blk_col, tx_size);
+                set_cb_buffer_offsets(dcb, tx_size, plane);
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
               }
-#else
-              td->predict_and_recon_intra_block_visit(
-                  cm, dcb, r, plane, blk_row, blk_col, tx_size);
-              set_cb_buffer_offsets(dcb, tx_size, plane);
-#endif  // CONFIG_CROSS_CHROMA_TX
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
             } else {
               // Reconstruction
               if (!mbmi->skip_txfm[xd->tree_type == CHROMA_PART]) {
@@ -1764,6 +1771,10 @@
   xd->mi = mi_params->mi_grid_base + offset;
   xd->tx_type_map =
       &mi_params->tx_type_map[mi_row * mi_params->mi_stride + mi_col];
+#if CONFIG_CROSS_CHROMA_TX
+  xd->cctx_type_map =
+      &mi_params->cctx_type_map[mi_row * mi_params->mi_stride + mi_col];
+#endif  // CONFIG_CROSS_CHROMA_TX
   xd->tx_type_map_stride = mi_params->mi_stride;
 
   set_plane_n4(xd, bw, bh, num_planes);
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 4e74c41..5181fcd 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -869,6 +869,34 @@
   }
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+void av1_read_cctx_type(const AV1_COMMON *const cm, MACROBLOCKD *xd,
+                        int blk_row, int blk_col, TX_SIZE tx_size,
+                        aom_reader *r) {
+  MB_MODE_INFO *mbmi = xd->mi[0];
+  CctxType *cctx_type =
+      &xd->cctx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
+  *cctx_type = CCTX_NONE;
+
+  // No need to read transform type if block is skipped.
+  if (mbmi->skip_txfm[xd->tree_type == CHROMA_PART] ||
+      segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP))
+    return;
+
+  // No need to read transform type for lossless mode(qindex==0).
+  const int qindex = xd->qindex[mbmi->segment_id];
+  if (qindex == 0) return;
+
+  const int is_inter = is_inter_block(mbmi, xd->tree_type);
+  if ((is_inter && CCTX_INTER) || (!is_inter && CCTX_INTRA)) {
+    FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
+    const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
+    *cctx_type = aom_read_symbol(r, ec_ctx->cctx_type_cdf[square_tx_size],
+                                 CCTX_TYPES, ACCT_STR);
+  }
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 #if CONFIG_IST
 void av1_read_sec_tx_type(const AV1_COMMON *const cm, MACROBLOCKD *xd,
                           int blk_row, int blk_col, TX_SIZE tx_size,
diff --git a/av1/decoder/decodemv.h b/av1/decoder/decodemv.h
index e8d40ba..ad34ec9 100644
--- a/av1/decoder/decodemv.h
+++ b/av1/decoder/decodemv.h
@@ -37,4 +37,10 @@
 void av1_read_tx_type(const AV1_COMMON *const cm, MACROBLOCKD *xd, int blk_row,
                       int blk_col, TX_SIZE tx_size, aom_reader *r);
 
+#if CONFIG_CROSS_CHROMA_TX
+void av1_read_cctx_type(const AV1_COMMON *const cm, MACROBLOCKD *xd,
+                        int blk_row, int blk_col, TX_SIZE tx_size,
+                        aom_reader *r);
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 #endif  // AOM_AV1_DECODER_DECODEMV_H_
diff --git a/av1/decoder/decoder.c b/av1/decoder/decoder.c
index 12b8201..0558a17 100644
--- a/av1/decoder/decoder.c
+++ b/av1/decoder/decoder.c
@@ -113,6 +113,10 @@
   mi_params->mi_alloc_size = 0;
   aom_free(mi_params->tx_type_map);
   mi_params->tx_type_map = NULL;
+#if CONFIG_CROSS_CHROMA_TX
+  aom_free(mi_params->cctx_type_map);
+  mi_params->cctx_type_map = NULL;
+#endif  // CONFIG_CROSS_CHROMA_TX
 }
 
 #if CONFIG_TIP
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 53e9b34..779408f 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -143,6 +143,20 @@
   MACROBLOCKD *const xd = &dcb->xd;
   FRAME_CONTEXT *const ec_ctx = xd->tile_ctx;
   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
+
+  eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
+  uint16_t *const eob = &(eob_data->eob);
+  uint16_t *const max_scan_line = &(eob_data->max_scan_line);
+  *max_scan_line = 0;
+  *eob = 0;
+
+#if CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
+  if (plane == AOM_PLANE_V) {
+    CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+    if (!keep_chroma_c2(cctx_type)) return 0;
+  }
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
+
 #if CONFIG_CONTEXT_DERIVATION
   if (plane == AOM_PLANE_U) {
     xd->eob_u = 0;
@@ -162,12 +176,6 @@
       r, ec_ctx->txb_skip_cdf[txs_ctx][txb_ctx->txb_skip_ctx], 2, ACCT_STR);
 #endif  // CONFIG_CONTEXT_DERIVATION
 
-  eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
-  uint16_t *const eob = &(eob_data->eob);
-  uint16_t *const max_scan_line = &(eob_data->max_scan_line);
-  *max_scan_line = 0;
-  *eob = 0;
-
 #if CONFIG_INSPECTION
   MB_MODE_INFO *const mbmi = xd->mi[0];
   if (plane == 0) {
@@ -183,6 +191,26 @@
   }
 #endif  // CONFIG_CONTEXT_DERIVATION
 
+#if CONFIG_CROSS_CHROMA_TX
+#if CCTX_C1_NONZERO
+  if (plane == AOM_PLANE_U) {
+    if (!all_zero)
+      av1_read_cctx_type(cm, xd, blk_row, blk_col, tx_size, r);
+    else
+      xd->cctx_type_map[blk_row * xd->tx_type_map_stride + blk_col] = CCTX_NONE;
+  }
+#else
+  if (plane == AOM_PLANE_V) {
+    // cctx_type will be read either eob_v > 0 or eob_u > 0
+    eob_info *eob_data_u =
+        dcb->eob_data[AOM_PLANE_U] + dcb->txb_offset[AOM_PLANE_U];
+    const uint16_t eob_u = eob_data_u->eob;
+    if (!all_zero || eob_u > 0)
+      av1_read_cctx_type(cm, xd, blk_row, blk_col, tx_size, r);
+  }
+#endif  // CCTX_C1_NONZERO
+#endif  // CONFIG_CROSS_CHROMA_TX
+
   if (all_zero) {
     *max_scan_line = 0;
     if (plane == 0) {
@@ -312,6 +340,12 @@
   uint8_t levels_buf[TX_PAD_2D];
   uint8_t *const levels = set_levels(levels_buf, width);
 #if !CONFIG_FORWARDSKIP
+#if CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
+  if (plane == AOM_PLANE_V) {
+    CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+    if (!keep_chroma_c2(cctx_type)) return 0;
+  }
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
 #if CONFIG_CONTEXT_DERIVATION
   int txb_skip_ctx = txb_ctx->txb_skip_ctx;
   int all_zero;
@@ -327,6 +361,22 @@
   const int all_zero = aom_read_symbol(
       r, ec_ctx->txb_skip_cdf[txs_ctx][txb_ctx->txb_skip_ctx], 2, ACCT_STR);
 #endif  // CONFIG_CONTEXT_DERIVATION
+
+#if CONFIG_CROSS_CHROMA_TX
+#if CCTX_C1_NONZERO
+  if (plane == AOM_PLANE_U && !all_zero) {
+    av1_read_cctx_type(cm, xd, blk_row, blk_col, tx_size, r);
+  }
+#else
+  if (plane == AOM_PLANE_V) {
+    eob_info *eob_data_u =
+        dcb->eob_data[AOM_PLANE_U] + dcb->txb_offset[AOM_PLANE_U];
+    uint16_t eob_u = eob_data_u->eob;
+    if (!all_zero || eob_u > 0)
+      av1_read_cctx_type(cm, xd, blk_row, blk_col, tx_size, r);
+  }
+#endif  // CCTX_C1_NONZERO
+#endif  // CONFIG_CROSS_CHROMA_TX
 #endif  // CONFIG_FORWARDSKIP
   eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
   uint16_t *const eob = &(eob_data->eob);
@@ -672,6 +722,9 @@
         for (int idy = 0; idy < txh; idy += tx_unit) {
           for (int idx = 0; idx < txw; idx += tx_unit) {
             xd->tx_type_map[(row + idy) * stride + col + idx] = tx_type;
+#if CONFIG_CROSS_CHROMA_TX
+            xd->cctx_type_map[(row + idy) * stride + col + idx] = CCTX_NONE;
+#endif  // CONFIG_CROSS_CHROMA_TX
           }
         }
       }
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index da8c9a2..c062195 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -510,6 +510,8 @@
     }
   }
 #else
+  (void)xd;
+  (void)mbmi;
   av1_write_coeffs_txb(cm, x, w, blk_row, blk_col, plane, block, tx_size);
 #endif  // CONFIG_FORWARDSKIP
 }
@@ -1156,6 +1158,25 @@
   }
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+void av1_write_cctx_type(const AV1_COMMON *const cm, const MACROBLOCKD *xd,
+                         CctxType cctx_type, TX_SIZE tx_size, aom_writer *w) {
+  MB_MODE_INFO *mbmi = xd->mi[0];
+  assert(xd->is_chroma_ref);
+  const int is_inter = is_inter_block(mbmi, xd->tree_type);
+  if (((!cm->seg.enabled && cm->quant_params.base_qindex > 0) ||
+       (cm->seg.enabled && xd->qindex[mbmi->segment_id] > 0)) &&
+      ((is_inter && CCTX_INTER) || (!is_inter && CCTX_INTRA)) &&
+      !mbmi->skip_txfm[xd->tree_type == CHROMA_PART] &&
+      !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+    FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
+    const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
+    aom_write_symbol(w, cctx_type, ec_ctx->cctx_type_cdf[square_tx_size],
+                     CCTX_TYPES);
+  }
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 #if CONFIG_IST
 void av1_write_sec_tx_type(const AV1_COMMON *const cm, const MACROBLOCKD *xd,
                            TX_TYPE tx_type, TX_SIZE tx_size, uint16_t eob,
@@ -2148,6 +2169,9 @@
       get_mi_ext_idx(mi_row, mi_col, cm->mi_params.mi_alloc_bsize,
                      cpi->mbmi_ext_info.stride);
   xd->tx_type_map = mi_params->tx_type_map + grid_idx;
+#if CONFIG_CROSS_CHROMA_TX
+  xd->cctx_type_map = mi_params->cctx_type_map + grid_idx;
+#endif  // CONFIG_CROSS_CHROMA_TX
   xd->tx_type_map_stride = mi_params->mi_stride;
 
   MB_MODE_INFO *mbmi = xd->mi[0];
diff --git a/av1/encoder/bitstream.h b/av1/encoder/bitstream.h
index 484c58c..8a7cd65 100644
--- a/av1/encoder/bitstream.h
+++ b/av1/encoder/bitstream.h
@@ -53,6 +53,11 @@
 void av1_write_tx_type(const AV1_COMMON *const cm, const MACROBLOCKD *xd,
                        TX_TYPE tx_type, TX_SIZE tx_size, aom_writer *w);
 
+#if CONFIG_CROSS_CHROMA_TX
+void av1_write_cctx_type(const AV1_COMMON *const cm, const MACROBLOCKD *xd,
+                         CctxType cctx_type, TX_SIZE tx_size, aom_writer *w);
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 6e97fee..01f70d5 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -277,6 +277,10 @@
   uint8_t blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
   //! Map showing the txfm types for each blcok.
   TX_TYPE tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+#if CONFIG_CROSS_CHROMA_TX
+  //! Map showing the cctx types for each block.
+  TX_TYPE cctx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+#endif  // CONFIG_CROSS_CHROMA_TX
   //! Rd_stats for the whole partition block.
   RD_STATS rd_stats;
   //! Hash value of the current record.
@@ -549,6 +553,10 @@
    * primary tx_type
    */
   TX_TYPE tx_type_map_[MAX_MIB_SIZE * MAX_MIB_SIZE];
+#if CONFIG_CROSS_CHROMA_TX
+  //! \brief CCTX types inside the partition block.
+  CctxType cctx_type_map_[MAX_MIB_SIZE * MAX_MIB_SIZE];
+#endif  // CONFIG_CROSS_CHROMA_TX
 
   /** \name Txfm hash records
    * Hash records of the transform search results based on the residue. There
@@ -845,6 +853,10 @@
   int intra_tx_type_costs[EXT_TX_SETS_INTRA][EXT_TX_SIZES][INTRA_MODES]
                          [TX_TYPES];
 #endif  // CONFIG_DDT_INTER
+#if CONFIG_CROSS_CHROMA_TX
+  //! cctx_type_cost
+  int cctx_type_cost[EXT_TX_SIZES][CCTX_TYPES];
+#endif  // CONFIG_CROSS_CHROMA_TX
   /**@}*/
 
   /*****************************************************************************
diff --git a/av1/encoder/context_tree.c b/av1/encoder/context_tree.c
index 99e18e6..6871c1f 100644
--- a/av1/encoder/context_tree.c
+++ b/av1/encoder/context_tree.c
@@ -32,6 +32,10 @@
          sizeof(uint8_t) * src_ctx->num_4x4_blk);
   av1_copy_array(dst_ctx->tx_type_map, src_ctx->tx_type_map,
                  src_ctx->num_4x4_blk);
+#if CONFIG_CROSS_CHROMA_TX
+  av1_copy_array(dst_ctx->cctx_type_map, src_ctx->cctx_type_map,
+                 src_ctx->num_4x4_blk);
+#endif  // CONFIG_CROSS_CHROMA_TX
 
   dst_ctx->hybrid_pred_diff = src_ctx->hybrid_pred_diff;
   dst_ctx->comp_pred_diff = src_ctx->comp_pred_diff;
@@ -81,6 +85,10 @@
                       aom_calloc(num_blk, sizeof(*ctx->blk_skip)));
   AOM_CHECK_MEM_ERROR(&error, ctx->tx_type_map,
                       aom_calloc(num_blk, sizeof(*ctx->tx_type_map)));
+#if CONFIG_CROSS_CHROMA_TX
+  AOM_CHECK_MEM_ERROR(&error, ctx->cctx_type_map,
+                      aom_calloc(num_blk, sizeof(*ctx->cctx_type_map)));
+#endif  // CONFIG_CROSS_CHROMA_TX
   ctx->num_4x4_blk = num_blk;
 
   for (int i = 0; i < num_planes; ++i) {
@@ -111,6 +119,9 @@
   aom_free(ctx->blk_skip);
   ctx->blk_skip = NULL;
   aom_free(ctx->tx_type_map);
+#if CONFIG_CROSS_CHROMA_TX
+  aom_free(ctx->cctx_type_map);
+#endif  // CONFIG_CROSS_CHROMA_TX
   for (int i = 0; i < num_planes; ++i) {
     ctx->coeff[i] = NULL;
     ctx->qcoeff[i] = NULL;
diff --git a/av1/encoder/context_tree.h b/av1/encoder/context_tree.h
index 83c37bf..b48a78a 100644
--- a/av1/encoder/context_tree.h
+++ b/av1/encoder/context_tree.h
@@ -45,6 +45,9 @@
   uint16_t *eobs[MAX_MB_PLANE];
   uint8_t *txb_entropy_ctx[MAX_MB_PLANE];
   TX_TYPE *tx_type_map;
+#if CONFIG_CROSS_CHROMA_TX
+  CctxType *cctx_type_map;
+#endif  // CONFIG_CROSS_CHROMA_TX
 
   int num_4x4_blk;
   // For current partition, only if all Y, U, and V transform blocks'
diff --git a/av1/encoder/encodeframe_utils.c b/av1/encoder/encodeframe_utils.c
index b96b734..dcb753e 100644
--- a/av1/encoder/encodeframe_utils.c
+++ b/av1/encoder/encodeframe_utils.c
@@ -171,6 +171,10 @@
   for (int row = 0; row < mi_size_high[mbmi->sb_type[plane_index]]; ++row) {
     memset(xd->tx_type_map + row * stride, DCT_DCT,
            bw * sizeof(xd->tx_type_map[0]));
+#if CONFIG_CROSS_CHROMA_TX
+    memset(xd->cctx_type_map + row * stride, CCTX_NONE,
+           bw * sizeof(xd->cctx_type_map[0]));
+#endif  // CONFIG_CROSS_CHROMA_TX
   }
   av1_zero(txfm_info->blk_skip);
   txfm_info->skip_txfm = 0;
@@ -242,6 +246,25 @@
     }
   }
 
+#if CONFIG_CROSS_CHROMA_TX
+  if (xd->tree_type != LUMA_PART) {
+    xd->cctx_type_map = ctx->cctx_type_map;
+    xd->tx_type_map_stride = mi_size_wide[bsize];
+    if (!dry_run) {
+      const int grid_idx = get_mi_grid_idx(mi_params, mi_row, mi_col);
+      CctxType *const cctx_type_map = mi_params->cctx_type_map + grid_idx;
+      const int mi_stride = mi_params->mi_stride;
+      for (int blk_row = 0; blk_row < bh; ++blk_row) {
+        av1_copy_array(cctx_type_map + blk_row * mi_stride,
+                       xd->cctx_type_map + blk_row * xd->tx_type_map_stride,
+                       bw);
+      }
+      xd->cctx_type_map = cctx_type_map;
+      xd->tx_type_map_stride = mi_stride;
+    }
+  }
+#endif
+
   // If segmentation in use
   if (seg->enabled) {
     // For in frame complexity AQ copy the segment id from the segment map.
@@ -1335,6 +1358,10 @@
   AVG_CDF_STRIDE(ctx_left->stx_cdf, ctx_tr->stx_cdf, STX_TYPES,
                  CDF_SIZE(STX_TYPES));
 #endif
+#if CONFIG_CROSS_CHROMA_TX
+  AVG_CDF_STRIDE(ctx_left->cctx_type_cdf, ctx_tr->cctx_type_cdf, CCTX_TYPES,
+                 CDF_SIZE(CCTX_TYPES));
+#endif  // CONFIG_CROSS_CHROMA_TX
 }
 
 // Memset the mbmis at the current superblock to 0
@@ -1362,6 +1389,10 @@
            sb_size_mi * sizeof(*mi_params->mi_grid_base));
     memset(&mi_params->tx_type_map[mi_grid_idx], 0,
            sb_size_mi * sizeof(*mi_params->tx_type_map));
+#if CONFIG_CROSS_CHROMA_TX
+    memset(&mi_params->cctx_type_map[mi_grid_idx], 0,
+           sb_size_mi * sizeof(*mi_params->cctx_type_map));
+#endif  // CONFIG_CROSS_CHROMA_TX
     if (cur_mi_row % mi_alloc_size_1d == 0) {
       memset(&mi_params->mi_alloc[alloc_mi_idx], 0,
              sb_size_alloc_mi * sizeof(*mi_params->mi_alloc));
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 3cc29b1..a2453b5 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -77,6 +77,9 @@
 
 int av1_optimize_b(const struct AV1_COMP *cpi, MACROBLOCK *x, int plane,
                    int block, TX_SIZE tx_size, TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                   CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
                    const TXB_CTX *const txb_ctx, int *rate_cost) {
   MACROBLOCKD *const xd = &x->e_mbd;
   struct macroblock_plane *const p = &x->plane[plane];
@@ -91,11 +94,17 @@
                                    x, block
 #endif  // CONFIG_CONTEXT_DERIVATION
     );
+#if CONFIG_CROSS_CHROMA_TX
+    *rate_cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+#endif  // CONFIG_CROSS_CHROMA_TX
     return eob;
   }
 
-  return av1_optimize_txb_new(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
-                              rate_cost, cpi->oxcf.algo_cfg.sharpness);
+  return av1_optimize_txb_new(cpi, x, plane, block, tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                              cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                              txb_ctx, rate_cost, cpi->oxcf.algo_cfg.sharpness);
 }
 
 // Hyper-parameters for dropout optimization, based on following logics.
@@ -274,33 +283,29 @@
   const int is_inter = is_inter_block(mbmi, xd->tree_type);
 #endif  // CONFIG_FORWARDSKIP
 #if CONFIG_CROSS_CHROMA_TX
-  if (is_inter_block(x->e_mbd.mi[0], x->e_mbd.tree_type)) {
-    switch (plane) {
-      case AOM_PLANE_Y:
+  if ((is_inter_block(mbmi, xd->tree_type) && CCTX_INTER) ||
+      (!is_inter_block(mbmi, xd->tree_type) && CCTX_INTRA)) {
+    // In the pipeline of cross-chroma transform, the forward transform for
+    // plane V is done earlier in plane U, followed by forward cross chroma
+    // transform, in order to obtain the quantized coefficients of the second
+    // channel.
+    if (plane != AOM_PLANE_V) {
 #if CONFIG_IST
-        av1_xform(x, AOM_PLANE_Y, block, blk_row, blk_col, plane_bsize,
-                  txfm_param, 0);
+      av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, txfm_param, 0);
 #else
-        av1_xform(x, AOM_PLANE_Y, block, blk_row, blk_col, plane_bsize,
-                  txfm_param);
+      av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, txfm_param);
 #endif
-        break;
-      case AOM_PLANE_U:
+    }
+    if (plane == AOM_PLANE_U) {
 #if CONFIG_IST
-        av1_xform(x, AOM_PLANE_U, block, blk_row, blk_col, plane_bsize,
-                  txfm_param, 0);
-        av1_xform(x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
-                  txfm_param, 0);
+      av1_xform(x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
+                txfm_param, 0);
 #else
-        av1_xform(x, AOM_PLANE_U, block, blk_row, blk_col, plane_bsize,
-                  txfm_param);
-        av1_xform(x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
-                  txfm_param);
+      av1_xform(x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
+                txfm_param);
 #endif
-        forward_cross_chroma_transform(x, block, txfm_param->tx_size);
-        // TODO(kslu): maybe skip av1_setup_xform for V
-        break;
-      case AOM_PLANE_V: break;
+      forward_cross_chroma_transform(x, block, txfm_param->tx_size,
+                                     txfm_param->cctx_type);
     }
   } else {
 #endif  // CONFIG_CROSS_CHROMA_TX
@@ -390,13 +395,14 @@
 }
 
 #if CONFIG_CROSS_CHROMA_TX
-void forward_cross_chroma_transform(MACROBLOCK *x, int block, TX_SIZE tx_size) {
+void forward_cross_chroma_transform(MACROBLOCK *x, int block, TX_SIZE tx_size,
+                                    CctxType cctx_type) {
   struct macroblock_plane *const p_u = &x->plane[AOM_PLANE_U];
   struct macroblock_plane *const p_v = &x->plane[AOM_PLANE_V];
   const int block_offset = BLOCK_OFFSET(block);
   tran_low_t *coeff_u = p_u->coeff + block_offset;
   tran_low_t *coeff_v = p_v->coeff + block_offset;
-  av1_fwd_cross_chroma_tx_block(coeff_u, coeff_v, tx_size);
+  av1_fwd_cross_chroma_tx_block(coeff_u, coeff_v, tx_size, cctx_type);
 }
 #endif  // CONFIG_CROSS_CHROMA_TX
 
@@ -452,7 +458,11 @@
 #if CONFIG_IST
                      int plane,
 #endif
-                     TX_SIZE tx_size, TX_TYPE tx_type, TxfmParam *txfm_param) {
+                     TX_SIZE tx_size, TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                     CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                     TxfmParam *txfm_param) {
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
 
@@ -473,6 +483,9 @@
 #else
   txfm_param->tx_type = tx_type;
 #endif
+#if CONFIG_CROSS_CHROMA_TX
+  txfm_param->cctx_type = cctx_type;
+#endif  // CONFIG_CROSS_CHROMA_TX
   txfm_param->tx_size = tx_size;
   txfm_param->lossless = xd->lossless[mbmi->segment_id];
   txfm_param->tx_set_type =
@@ -528,7 +541,7 @@
   MB_MODE_INFO *mbmi = xd->mi[0];
   struct macroblock_plane *const p = &x->plane[plane];
   struct macroblockd_plane *const pd = &xd->plane[plane];
-#if CONFIG_IST || CONFIG_CROSS_CHROMA_TX
+#if CONFIG_IST || (CONFIG_CROSS_CHROMA_TX && CCTX_INTER)
   tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
 #else
   tran_low_t *const dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
@@ -543,16 +556,30 @@
   a = &args->ta[blk_col];
   l = &args->tl[blk_row];
 
-  TX_TYPE tx_type = DCT_DCT;
+  TX_TYPE tx_type = av1_get_tx_type(xd, pd->plane_type, blk_row, blk_col,
+                                    tx_size, cm->features.reduced_tx_set_used);
+#if CONFIG_CROSS_CHROMA_TX
+  CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+#endif  // CONFIG_CROSS_CHROMA_TX
+
   if (!is_blk_skip(x->txfm_search_info.blk_skip, plane,
                    blk_row * bw + blk_col) &&
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTER && CCTX_C1_NONZERO
+#if CCTX_C2_DROPPED
+      (plane < AOM_PLANE_V ||
+       ((cctx_type == CCTX_NONE || x->plane[AOM_PLANE_U].eobs[block]) &&
+        keep_chroma_c2(cctx_type))) &&
+#else
+      (plane < AOM_PLANE_V || cctx_type == CCTX_NONE ||
+       x->plane[AOM_PLANE_U].eobs[block]) &&
+#endif
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTER && CCTX_C1_NONZERO
 #if CONFIG_SKIP_MODE_ENHANCEMENT
       !(mbmi->skip_mode == 1)) {
 #else
       !mbmi->skip_mode) {
 #endif  // CONFIG_SKIP_MODE_ENHANCEMENT
-    tx_type = av1_get_tx_type(xd, pd->plane_type, blk_row, blk_col, tx_size,
-                              cm->features.reduced_tx_set_used);
+
     TxfmParam txfm_param;
     QUANT_PARAM quant_param;
 #if CONFIG_FORWARDSKIP
@@ -576,7 +603,11 @@
 #if CONFIG_IST
                     plane,
 #endif
-                    tx_size, tx_type, &txfm_param);
+                    tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                    cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                    &txfm_param);
     av1_setup_quant(tx_size, use_trellis, quant_idx,
                     cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
     av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
@@ -602,8 +633,11 @@
                   mbmi->fsc_mode[xd->tree_type == CHROMA_PART]
 #endif  // CONFIG_FORWARDSKIP
       );
-      av1_optimize_b(args->cpi, x, plane, block, tx_size, tx_type, &txb_ctx,
-                     &dummy_rate_cost);
+      av1_optimize_b(args->cpi, x, plane, block, tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                     cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                     &txb_ctx, &dummy_rate_cost);
     }
     if (!quant_param.use_optimize_b && do_dropout
 #if CONFIG_FORWARDSKIP
@@ -613,46 +647,59 @@
       av1_dropout_qcoeff(x, plane, block, tx_size, tx_type,
                          cm->quant_params.base_qindex);
     }
+#if CONFIG_CROSS_CHROMA_TX && CCTX_C1_NONZERO
+    // Since eob can be updated here, make sure cctx_type is always CCTX_NONE
+    // when eob of U is 0.
+    // TODO(kslu) why cctx_type can be > CCTX_NONE when eob_u is 0?
+    if (plane == AOM_PLANE_U && p->eobs[block] == 0)
+      update_cctx_array(xd, blk_row, blk_col, tx_size, CCTX_NONE);
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_C1_NONZERO
   } else {
+#if CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
+    // Reset coeffs and dqcoeffs
+    if (plane == AOM_PLANE_V && !keep_chroma_c2(cctx_type))
+      av1_quantize_skip(av1_get_max_eob(tx_size),
+                        p->coeff + BLOCK_OFFSET(block), dqcoeff,
+                        &p->eobs[block]);
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
     p->eobs[block] = 0;
     p->txb_entropy_ctx[block] = 0;
   }
 
   av1_set_txb_context(x, plane, block, tx_size, a, l);
 
-#if CONFIG_CROSS_CHROMA_TX
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTER
   // In CONFIG_CROSS_CHROMA_TX, reconstruction for U plane relies on dqcoeffs of
   // V plane, so the below operations for U are performed together with V once
   // dqcoeffs of V are obtained.
-  if (is_inter_block(mbmi, xd->tree_type) && plane == AOM_PLANE_U) {
+  if (plane == AOM_PLANE_U) {
     if (p->eobs[block]) *(args->skip) = 0;
     return;
-  } else if (is_inter_block(mbmi, xd->tree_type) && plane == AOM_PLANE_V) {
-    struct macroblock_plane *const p_u = &x->plane[AOM_PLANE_U];
-    tran_low_t *dqcoeff_u = x->plane[AOM_PLANE_U].dqcoeff + BLOCK_OFFSET(block);
-    struct macroblockd_plane *const pd_u = &xd->plane[AOM_PLANE_U];
-    uint8_t *dst_u =
-        &pd_u->dst.buf[(blk_row * pd_u->dst.stride + blk_col) << MI_SIZE_LOG2];
-    av1_inv_cross_chroma_tx_block(dqcoeff_u, dqcoeff, tx_size);
+  }
+  struct macroblock_plane *const p_u = &x->plane[AOM_PLANE_U];
+  struct macroblockd_plane *const pd_u = &xd->plane[AOM_PLANE_U];
+  tran_low_t *dqcoeff_u = p_u->dqcoeff + BLOCK_OFFSET(block);
+  uint8_t *dst_u =
+      &pd_u->dst.buf[(blk_row * pd_u->dst.stride + blk_col) << MI_SIZE_LOG2];
+  int eob_u = p_u->eobs[block];
+  int eob_v = x->plane[AOM_PLANE_V].eobs[block];
+  if (plane == AOM_PLANE_V && (eob_u || eob_v)) {
+    av1_inv_cross_chroma_tx_block(dqcoeff_u, dqcoeff, tx_size, cctx_type);
     av1_inverse_transform_block(xd, dqcoeff_u, AOM_PLANE_U, tx_type, tx_size,
-                                dst_u, pd_u->dst.stride,
-                                AOMMAX(p_u->eobs[block], p->eobs[block]),
+                                dst_u, pd_u->dst.stride, AOMMAX(eob_u, eob_v),
                                 cm->features.reduced_tx_set_used);
   }
 
   // TODO(kslu): keep track of transform domain eobs for U and V
-  if (p->eobs[block] || (plane && (x->plane[AOM_PLANE_U].eobs[block] ||
-                                   x->plane[AOM_PLANE_V].eobs[block]))) {
+  if (p->eobs[block] || (plane && (eob_u || eob_v))) {
 #else
   if (p->eobs[block]) {
 #endif  // CONFIG_CROSS_CHROMA_TX
     *(args->skip) = 0;
     av1_inverse_transform_block(
         xd, dqcoeff, plane, tx_type, tx_size, dst, pd->dst.stride,
-#if CONFIG_CROSS_CHROMA_TX
-        (plane == 0) ? p->eobs[block]
-                     : AOMMAX(x->plane[AOM_PLANE_U].eobs[block],
-                              x->plane[AOM_PLANE_V].eobs[block]),
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTER
+        (plane == 0) ? p->eobs[block] : AOMMAX(eob_u, eob_v),
 #else
         p->eobs[block],
 #endif
@@ -688,17 +735,11 @@
     int blk_h = block_size_high[bsize];
     mi_to_pixel_loc(&pixel_c, &pixel_r, xd->mi_col, xd->mi_row, blk_col,
                     blk_row, pd->subsampling_x, pd->subsampling_y);
-#if CONFIG_CROSS_CHROMA_TX
-    if (plane == AOM_PLANE_V) {
-      struct macroblockd_plane *const pd_u = &xd->plane[AOM_PLANE_U];
-      uint8_t *dst_u =
-          &pd_u->dst
-               .buf[(blk_row * pd_u->dst.stride + blk_col) << MI_SIZE_LOG2];
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTER
+    if (plane == AOM_PLANE_V)
       mismatch_record_block_tx(dst_u, pd_u->dst.stride,
                                cm->current_frame.order_hint, AOM_PLANE_U,
-                               pixel_c, pixel_r, blk_w, blk_h,
-                               xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH);
-    }
+                               pixel_c, pixel_r, blk_w, blk_h);
 #endif  // CONFIG_CROSS_CHROMA_TX
     mismatch_record_block_tx(dst, pd->dst.stride, cm->current_frame.order_hint,
                              plane, pixel_c, pixel_r, blk_w, blk_h);
@@ -849,7 +890,11 @@
 #if CONFIG_IST
                   plane,
 #endif
-                  tx_size, DCT_DCT, &txfm_param);
+                  tx_size, DCT_DCT,
+#if CONFIG_CROSS_CHROMA_TX
+                  CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                  &txfm_param);
   av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                   &quant_param);
   av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, DCT_DCT,
@@ -897,7 +942,7 @@
     cpi,  x,    &ctx,    &mbmi->skip_txfm[xd->tree_type == CHROMA_PART],
     NULL, NULL, dry_run, cpi->optimize_seg_arr[mbmi->segment_id]
   };
-#if CONFIG_CROSS_CHROMA_TX
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTER
   // Subtract first, so both U and V residues will be available when U component
   // is being transformed and quantized.
   for (int plane = plane_start; plane < plane_end; ++plane) {
@@ -926,7 +971,7 @@
     const int step =
         tx_size_wide_unit[max_tx_size] * tx_size_high_unit[max_tx_size];
     av1_get_entropy_contexts(plane_bsize, pd, ctx.ta[plane], ctx.tl[plane]);
-#if !CONFIG_CROSS_CHROMA_TX
+#if !(CONFIG_CROSS_CHROMA_TX && CCTX_INTER)
     av1_subtract_plane(x, plane_bsize, plane);
 #endif  // !CONFIG_CROSS_CHROMA_TX
 
@@ -1001,6 +1046,7 @@
   av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
 
   TX_TYPE tx_type = DCT_DCT;
+
   const int bw = mi_size_wide[plane_bsize];
 #if DEBUG_EXTQUANT
   if (args->dry_run == OUTPUT_ENABLED) {
@@ -1051,7 +1097,11 @@
 #if CONFIG_IST
                     plane,
 #endif
-                    tx_size, tx_type, &txfm_param);
+                    tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                    CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                    &txfm_param);
     av1_setup_quant(tx_size, use_trellis, quant_idx,
                     cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
     av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
@@ -1094,8 +1144,11 @@
                   mbmi->fsc_mode[xd->tree_type == CHROMA_PART]
 #endif  // CONFIG_FORWARDSKIP
       );
-      av1_optimize_b(args->cpi, x, plane, block, tx_size, tx_type, &txb_ctx,
-                     &dummy_rate_cost);
+      av1_optimize_b(args->cpi, x, plane, block, tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                     CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                     &txb_ctx, &dummy_rate_cost);
     }
     if (do_dropout
 #if CONFIG_FORWARDSKIP
@@ -1108,7 +1161,6 @@
   }
 
   if (*eob) {
-    // TODO(kslu) apply inv cctx for u plane once it is needed for intra
     av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
                                 dst_stride, *eob,
                                 cm->features.reduced_tx_set_used);
@@ -1161,3 +1213,180 @@
   av1_foreach_transformed_block_in_plane(
       xd, plane_bsize, plane, encode_block_intra_and_set_context, &arg);
 }
+
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+void av1_encode_block_intra_joint_uv(int block, int blk_row, int blk_col,
+                                     BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
+                                     void *arg) {
+  struct encode_b_args *const args = arg;
+  const AV1_COMP *const cpi = args->cpi;
+  const AV1_COMMON *const cm = &cpi->common;
+  MACROBLOCK *const x = args->x;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  struct macroblock_plane *const p_u = &x->plane[AOM_PLANE_U];
+  struct macroblock_plane *const p_v = &x->plane[AOM_PLANE_V];
+  struct macroblockd_plane *const pd_u = &xd->plane[AOM_PLANE_U];
+  struct macroblockd_plane *const pd_v = &xd->plane[AOM_PLANE_V];
+  tran_low_t *dqcoeff_u = p_u->dqcoeff + BLOCK_OFFSET(block);
+  tran_low_t *dqcoeff_v = p_v->dqcoeff + BLOCK_OFFSET(block);
+  uint16_t *eob_u = &p_u->eobs[block];
+  uint16_t *eob_v = &p_v->eobs[block];
+  const int dst_stride = pd_u->dst.stride;
+  uint8_t *dst_u =
+      &pd_u->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
+  uint8_t *dst_v =
+      &pd_v->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
+  int dummy_rate_cost = 0;
+
+  av1_predict_intra_block_facade(cm, xd, AOM_PLANE_U, blk_col, blk_row,
+                                 tx_size);
+  av1_predict_intra_block_facade(cm, xd, AOM_PLANE_V, blk_col, blk_row,
+                                 tx_size);
+
+  TX_TYPE tx_type = av1_get_tx_type(xd, PLANE_TYPE_UV, blk_row, blk_col,
+                                    tx_size, cm->features.reduced_tx_set_used);
+  CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+
+  av1_subtract_txb(x, AOM_PLANE_U, plane_bsize, blk_col, blk_row, tx_size);
+  av1_subtract_txb(x, AOM_PLANE_V, plane_bsize, blk_col, blk_row, tx_size);
+
+  TxfmParam txfm_param;
+  QUANT_PARAM quant_param;
+  const int use_trellis =
+      is_trellis_used(args->enable_optimize_b, args->dry_run);
+  int quant_idx;
+  if (use_trellis)
+    quant_idx = AV1_XFORM_QUANT_FP;
+  else
+    quant_idx = USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP;
+
+  av1_setup_xform(cm, x,
+#if CONFIG_IST
+                  AOM_PLANE_U,
+#endif
+                  tx_size, tx_type, cctx_type, &txfm_param);
+  av1_setup_quant(tx_size, use_trellis, quant_idx,
+                  cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
+  // Whether trellis or dropout optimization is required for key frames and
+  // intra frames.
+  const bool do_trellis = (frame_is_intra_only(cm) &&
+                           (KEY_BLOCK_OPT_TYPE == TRELLIS_OPT ||
+                            KEY_BLOCK_OPT_TYPE == TRELLIS_DROPOUT_OPT)) ||
+                          (!frame_is_intra_only(cm) &&
+                           (INTRA_BLOCK_OPT_TYPE == TRELLIS_OPT ||
+                            INTRA_BLOCK_OPT_TYPE == TRELLIS_DROPOUT_OPT));
+  const bool do_dropout = (frame_is_intra_only(cm) &&
+                           (KEY_BLOCK_OPT_TYPE == DROPOUT_OPT ||
+                            KEY_BLOCK_OPT_TYPE == TRELLIS_DROPOUT_OPT)) ||
+                          (!frame_is_intra_only(cm) &&
+                           (INTRA_BLOCK_OPT_TYPE == DROPOUT_OPT ||
+                            INTRA_BLOCK_OPT_TYPE == TRELLIS_DROPOUT_OPT));
+
+  for (int plane = AOM_PLANE_U; plane <= AOM_PLANE_V; plane++) {
+#if CCTX_C1_NONZERO
+#if CCTX_C2_DROPPED
+    if (plane == AOM_PLANE_V && (!keep_chroma_c2(cctx_type) ||
+                                 (*eob_u == 0 && cctx_type > CCTX_NONE))) {
+#else
+    if (plane == AOM_PLANE_V && *eob_u == 0 && cctx_type > CCTX_NONE) {
+#endif
+      // Since eob can be updated here, make sure cctx_type is always CCTX_NONE
+      // when eob of U is 0.
+      if (*eob_u == 0 && cctx_type > CCTX_NONE)
+        update_cctx_array(xd, blk_row, blk_col, tx_size, CCTX_NONE);
+      av1_quantize_skip(av1_get_max_eob(tx_size),
+                        p_v->coeff + BLOCK_OFFSET(block), dqcoeff_v, eob_v);
+      break;
+    }
+#endif
+    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
+                      &quant_param);
+    av1_xform_quant(
+#if CONFIG_FORWARDSKIP
+        cm,
+#endif  // CONFIG_FORWARDSKIP
+        x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+        &quant_param);
+    if (quant_param.use_optimize_b && do_trellis) {
+      const ENTROPY_CONTEXT *a =
+          &args->ta[blk_col + (plane - AOM_PLANE_U) * MAX_MIB_SIZE];
+      const ENTROPY_CONTEXT *l =
+          &args->tl[blk_row + (plane - AOM_PLANE_U) * MAX_MIB_SIZE];
+      TXB_CTX txb_ctx;
+      get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx
+#if CONFIG_FORWARDSKIP
+                  ,
+                  xd->mi[0]->fsc_mode[xd->tree_type == CHROMA_PART]
+#endif  // CONFIG_FORWARDSKIP
+      );
+      av1_optimize_b(args->cpi, x, plane, block, tx_size, tx_type, cctx_type,
+                     &txb_ctx, &dummy_rate_cost);
+    }
+    if (do_dropout) {
+      av1_dropout_qcoeff(x, plane, block, tx_size, tx_type,
+                         cm->quant_params.base_qindex);
+    }
+  }
+
+  av1_inv_cross_chroma_tx_block(dqcoeff_u, dqcoeff_v, tx_size, cctx_type);
+  if (*eob_u || *eob_v) {
+    // TODO(kslu) keep track of transform domain eobs for U and V
+    av1_inverse_transform_block(xd, dqcoeff_u, AOM_PLANE_U, tx_type, tx_size,
+                                dst_u, dst_stride, AOMMAX(*eob_u, *eob_v),
+                                cm->features.reduced_tx_set_used);
+    av1_inverse_transform_block(xd, dqcoeff_v, AOM_PLANE_V, tx_type, tx_size,
+                                dst_v, dst_stride, AOMMAX(*eob_u, *eob_v),
+                                cm->features.reduced_tx_set_used);
+  }
+
+  // For intra mode, skipped blocks are so rare that transmitting skip=1 is
+  // very expensive.
+  *(args->skip) = 0;
+}
+
+static void encode_block_intra_and_set_context_joint_uv(
+    int plane, int block, int blk_row, int blk_col, BLOCK_SIZE plane_bsize,
+    TX_SIZE tx_size, void *arg) {
+  (void)plane;
+  av1_encode_block_intra_joint_uv(block, blk_row, blk_col, plane_bsize, tx_size,
+                                  arg);
+
+  struct encode_b_args *const args = arg;
+  MACROBLOCK *x = args->x;
+  ENTROPY_CONTEXT *au = &args->ta[blk_col];
+  ENTROPY_CONTEXT *lu = &args->tl[blk_row];
+  ENTROPY_CONTEXT *av = &args->ta[MAX_MIB_SIZE + blk_col];
+  ENTROPY_CONTEXT *lv = &args->tl[MAX_MIB_SIZE + blk_row];
+  av1_set_txb_context(x, AOM_PLANE_U, block, tx_size, au, lu);
+  av1_set_txb_context(x, AOM_PLANE_V, block, tx_size, av, lv);
+}
+
+void av1_encode_intra_block_joint_uv(const struct AV1_COMP *cpi, MACROBLOCK *x,
+                                     BLOCK_SIZE bsize, RUN_TYPE dry_run,
+                                     TRELLIS_OPT_TYPE enable_optimize_b) {
+  assert(bsize < BLOCK_SIZES_ALL);
+  const MACROBLOCKD *const xd = &x->e_mbd;
+  if (!xd->is_chroma_ref) return;
+
+  const struct macroblockd_plane *const pd_u = &xd->plane[AOM_PLANE_U];
+  const struct macroblockd_plane *const pd_v = &xd->plane[AOM_PLANE_V];
+  const int ss_x = pd_u->subsampling_x;
+  const int ss_y = pd_u->subsampling_y;
+  assert(ss_x == pd_v->subsampling_x && ss_y == pd_v->subsampling_y);
+  ENTROPY_CONTEXT ta[MAX_MIB_SIZE * 2] = { 0 };
+  ENTROPY_CONTEXT tl[MAX_MIB_SIZE * 2] = { 0 };
+  struct encode_b_args arg = {
+    cpi, x,  NULL,    &(xd->mi[0]->skip_txfm[xd->tree_type == CHROMA_PART]),
+    ta,  tl, dry_run, enable_optimize_b
+  };
+  const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, ss_x, ss_y);
+  if (enable_optimize_b) {
+    av1_get_entropy_contexts(plane_bsize, pd_u, ta, tl);
+    av1_get_entropy_contexts(plane_bsize, pd_v, &ta[MAX_MIB_SIZE],
+                             &tl[MAX_MIB_SIZE]);
+  }
+  av1_foreach_transformed_block_in_plane(
+      xd, plane_bsize, AOM_PLANE_U, encode_block_intra_and_set_context_joint_uv,
+      &arg);
+}
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
diff --git a/av1/encoder/encodemb.h b/av1/encoder/encodemb.h
index 47f0b36..ade30ee 100644
--- a/av1/encoder/encodemb.h
+++ b/av1/encoder/encodemb.h
@@ -78,7 +78,11 @@
 #if CONFIG_IST
                      int plane,
 #endif
-                     TX_SIZE tx_size, TX_TYPE tx_type, TxfmParam *txfm_param);
+                     TX_SIZE tx_size, TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                     CctxType cctx_Type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                     TxfmParam *txfm_param);
 void av1_setup_quant(TX_SIZE tx_size, int use_optimize_b, int xform_quant_idx,
                      int use_quant_b_adapt, QUANT_PARAM *qparam);
 
@@ -110,7 +114,8 @@
 );
 
 #if CONFIG_CROSS_CHROMA_TX
-void forward_cross_chroma_transform(MACROBLOCK *x, int block, TX_SIZE tx_size);
+void forward_cross_chroma_transform(MACROBLOCK *x, int block, TX_SIZE tx_size,
+                                    CctxType cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
 
 void av1_quant(MACROBLOCK *x, int plane, int block, TxfmParam *txfm_param,
@@ -118,6 +123,9 @@
 
 int av1_optimize_b(const struct AV1_COMP *cpi, MACROBLOCK *mb, int plane,
                    int block, TX_SIZE tx_size, TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                   CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
                    const TXB_CTX *const txb_ctx, int *rate_cost);
 
 // This function can be used as (i) a further optimization to reduce the
@@ -170,7 +178,11 @@
 void av1_encode_intra_block_plane(const struct AV1_COMP *cpi, MACROBLOCK *x,
                                   BLOCK_SIZE bsize, int plane, RUN_TYPE dry_run,
                                   TRELLIS_OPT_TYPE enable_optimize_b);
-
+#if CONFIG_CROSS_CHROMA_TX
+void av1_encode_intra_block_joint_uv(const struct AV1_COMP *cpi, MACROBLOCK *x,
+                                     BLOCK_SIZE bsize, RUN_TYPE dry_run,
+                                     TRELLIS_OPT_TYPE enable_optimize_b);
+#endif  // CONFIG_CROSS_CHROMA_TX
 static INLINE int is_trellis_used(TRELLIS_OPT_TYPE optimize_b,
                                   RUN_TYPE dry_run) {
   if (optimize_b == NO_TRELLIS_OPT) return false;
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 3072f2f..6d8cd79 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -1297,6 +1297,9 @@
   unsigned int intra_ext_tx[EXT_TX_SETS_INTRA][EXT_TX_SIZES][INTRA_MODES]
                            [TX_TYPES];
 #endif  // CONFIG_DDT_INTER
+#if CONFIG_CROSS_CHROMA_TX
+  unsigned int cctx_type[EXT_TX_SIZES][CCTX_TYPES];
+#endif  // CONFIG_CROSS_CHROMA_TX
   unsigned int filter_intra_mode[FILTER_INTRA_MODES];
   unsigned int filter_intra[BLOCK_SIZES_ALL][2];
   unsigned int switchable_restore[RESTORE_SWITCHABLE_TYPES];
diff --git a/av1/encoder/encoder_utils.h b/av1/encoder/encoder_utils.h
index e84d0eb..f9b2f96 100644
--- a/av1/encoder/encoder_utils.h
+++ b/av1/encoder/encoder_utils.h
@@ -87,6 +87,10 @@
   mi_params->mi_alloc_size = 0;
   aom_free(mi_params->tx_type_map);
   mi_params->tx_type_map = NULL;
+#if CONFIG_CROSS_CHROMA_TX
+  aom_free(mi_params->cctx_type_map);
+  mi_params->cctx_type_map = NULL;
+#endif  // CONFIG_CROSS_CHROMA_TX
 }
 
 static AOM_INLINE void enc_set_mb_mi(CommonModeInfoParams *mi_params, int width,
@@ -113,6 +117,10 @@
          mi_grid_size * sizeof(*mi_params->mi_grid_base));
   memset(mi_params->tx_type_map, 0,
          mi_grid_size * sizeof(*mi_params->tx_type_map));
+#if CONFIG_CROSS_CHROMA_TX
+  memset(mi_params->cctx_type_map, 0,
+         mi_grid_size * sizeof(*mi_params->cctx_type_map));
+#endif  // CONFIG_CROSS_CHROMA_TX
 }
 
 static AOM_INLINE void init_buffer_indices(
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 85eebe2..4ca8a19 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -435,6 +435,13 @@
       av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
                       cm->features.reduced_tx_set_used);
 
+#if CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
+  if (plane == AOM_PLANE_V) {
+    CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+    if (!keep_chroma_c2(cctx_type)) return 0;
+  }
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
+
 #if CONFIG_CONTEXT_DERIVATION
   if (plane == AOM_PLANE_U) {
     xd->eob_u_flag = eob ? 1 : 0;
@@ -448,7 +455,29 @@
 #else
   aom_write_symbol(w, eob == 0, ec_ctx->txb_skip_cdf[txs_ctx][txb_skip_ctx], 2);
 #endif  // CONFIG_CONTEXT_DERIVATION
+
+#if CONFIG_CROSS_CHROMA_TX
+#if CCTX_C1_NONZERO
+  if (plane == AOM_PLANE_U && eob > 0) {
+    CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+    av1_write_cctx_type(cm, xd, cctx_type, tx_size, w);
+  }
+#else
+  // tx_type is signaled with Y plane if eob > 0. cctx_type is signaled with V
+  // plane if either of eob_u and eob_v is > 0.
+  if (plane == AOM_PLANE_V) {
+    const uint16_t *eob_txb_u = cb_coef_buff->eobs[AOM_PLANE_U] + txb_offset;
+    const uint16_t eob_u = eob_txb_u[block];
+    if (eob > 0 || eob_u > 0) {
+      const CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+      av1_write_cctx_type(cm, xd, cctx_type, tx_size, w);
+    }
+  }
+#endif  // CCTX_C1_NONZERO
+#endif  // CONFIG_CROSS_CHROMA_TX
+
   if (eob == 0) return 0;
+
   if (plane == 0) {  // Only y plane's tx_type is transmitted
     av1_write_tx_type(cm, xd, tx_type, tx_size, w);
   }
@@ -550,6 +579,12 @@
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
 
 #if !CONFIG_FORWARDSKIP
+#if CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
+  if (plane == AOM_PLANE_V) {
+    CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+    if (!keep_chroma_c2(cctx_type)) return 0;
+  }
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
 #if CONFIG_CONTEXT_DERIVATION
   if (plane == AOM_PLANE_U) {
     xd->eob_u_flag = eob ? 1 : 0;
@@ -563,7 +598,27 @@
 #else
   aom_write_symbol(w, eob == 0, ec_ctx->txb_skip_cdf[txs_ctx][txb_skip_ctx], 2);
 #endif  // CONFIG_CONTEXT_DERIVATION
+
+#if CONFIG_CROSS_CHROMA_TX
+#if CCTX_C1_NONZERO
+  if (plane == AOM_PLANE_U && eob > 0) {
+    CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+    av1_write_cctx_type(cm, xd, cctx_type, tx_size, w);
+  }
+#else
+  // CCTX type is transmitted with V plane
+  if (plane == AOM_PLANE_V) {
+    const uint16_t *eob_txb_u = cb_coef_buff->eobs[AOM_PLANE_U] + txb_offset;
+    const uint16_t eob_u = eob_txb_u[block];
+    if (eob > 0 || eob_u > 0) {
+      const CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+      av1_write_cctx_type(cm, xd, cctx_type, tx_size, w);
+    }
+  }
+#endif  // CCTX_C1_NONZERO
+#endif  // CONFIG_CROSS_CHROMA_TX
 #endif  // CONFIG_FORWARDSKIP
+
   if (eob == 0) return;
 
   const PLANE_TYPE plane_type = get_plane_type(plane);
@@ -823,6 +878,26 @@
   }
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+int get_cctx_type_cost(const MACROBLOCK *x, const MACROBLOCKD *xd, int plane,
+                       TX_SIZE tx_size, int block, CctxType cctx_type) {
+  const int is_inter = is_inter_block(xd->mi[0], xd->tree_type);
+  const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
+#if CCTX_C1_NONZERO
+  (void)block;
+  if (plane == AOM_PLANE_U && x->plane[plane].eobs[block] &&
+#else
+  if (plane == AOM_PLANE_V &&
+      (x->plane[AOM_PLANE_U].eobs[block] ||
+       x->plane[AOM_PLANE_V].eobs[block]) &&
+#endif  // CCTX_C1_NONZERO
+      ((is_inter && CCTX_INTER) || (!is_inter && CCTX_INTRA)))
+    return x->mode_costs.cctx_type_cost[square_tx_size][cctx_type];
+  else
+    return 0;
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 // TODO(angiebird): use this function whenever it's possible
 static int get_tx_type_cost(const MACROBLOCK *x, const MACROBLOCKD *xd,
                             int plane, TX_SIZE tx_size, TX_TYPE tx_type,
@@ -832,11 +907,11 @@
                             int eob
 #endif
 ) {
-  if (plane > 0) return 0;
-
+  const MB_MODE_INFO *mbmi = xd->mi[0];
   const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
 
-  const MB_MODE_INFO *mbmi = xd->mi[0];
+  if (plane > 0) return 0;
+
 #if CONFIG_FORWARDSKIP
   if (mbmi->fsc_mode[xd->tree_type == CHROMA_PART] &&
       !is_inter_block(mbmi, xd->tree_type) && plane == PLANE_TYPE_Y) {
@@ -953,7 +1028,11 @@
     const TX_SIZE tx_size, const TXB_CTX *const txb_ctx,
     const struct macroblock_plane *p, const int eob,
     const LV_MAP_COEFF_COST *const coeff_costs, const MACROBLOCKD *const xd,
-    const TX_TYPE tx_type, int reduced_tx_set_used) {
+    const TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+    const CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+    int reduced_tx_set_used) {
   const tran_low_t *const qcoeff = p->qcoeff + BLOCK_OFFSET(block);
   const int txb_skip_ctx = txb_ctx->txb_skip_ctx;
   const int bwl = get_txb_bwl(tx_size);
@@ -973,6 +1052,9 @@
                            eob
 #endif  // CONFIG_IST
   );
+#if CONFIG_CROSS_CHROMA_TX
+  cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+#endif  // CONFIG_CROSS_CHROMA_TX
   DECLARE_ALIGNED(16, int8_t, coeff_contexts[MAX_TX_SQUARE]);
   av1_get_nz_map_contexts_skip(levels, scan, eob, tx_size, coeff_contexts);
   const int(*lps_cost)[COEFF_BASE_RANGE + 1 + COEFF_BASE_RANGE + 1] =
@@ -1010,8 +1092,11 @@
     const TX_SIZE tx_size, const TXB_CTX *const txb_ctx,
     const struct macroblock_plane *p, const int eob,
     const PLANE_TYPE plane_type, const LV_MAP_COEFF_COST *const coeff_costs,
-    const MACROBLOCKD *const xd, const TX_TYPE tx_type, const TX_CLASS tx_class,
-    int reduced_tx_set_used) {
+    const MACROBLOCKD *const xd, const TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+    const CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+    const TX_CLASS tx_class, int reduced_tx_set_used) {
   const tran_low_t *const qcoeff = p->qcoeff + BLOCK_OFFSET(block);
 #if CONFIG_CONTEXT_DERIVATION
   const struct macroblock_plane *pu = &x->plane[AOM_PLANE_U];
@@ -1052,6 +1137,9 @@
                            eob
 #endif
   );
+#if CONFIG_CROSS_CHROMA_TX
+  cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+#endif  // CONFIG_CROSS_CHROMA_TX
 
   cost += get_eob_cost(eob, eob_costs, coeff_costs, tx_class);
 
@@ -1167,8 +1255,11 @@
     const MACROBLOCK *x, const int plane, const int block,
     const TX_SIZE tx_size, const TXB_CTX *const txb_ctx, const int eob,
     const PLANE_TYPE plane_type, const LV_MAP_COEFF_COST *const coeff_costs,
-    const MACROBLOCKD *const xd, const TX_TYPE tx_type, const TX_CLASS tx_class,
-    int reduced_tx_set_used) {
+    const MACROBLOCKD *const xd, const TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+    const CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+    const TX_CLASS tx_class, int reduced_tx_set_used) {
 #if CONFIG_CONTEXT_DERIVATION
   int txb_skip_ctx = txb_ctx->txb_skip_ctx;
   if (plane == AOM_PLANE_V) {
@@ -1199,6 +1290,9 @@
                            eob
 #endif  // CONFIG_IST
   );
+#if CONFIG_CROSS_CHROMA_TX
+  cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+#endif  // CONFIG_CROSS_CHROMA_TX
 
 #if !CONFIG_FORWARDSKIP
   cost += get_eob_cost(eob, eob_costs, coeff_costs, tx_class);
@@ -1300,6 +1394,9 @@
 int av1_cost_coeffs_txb(const MACROBLOCK *x, const int plane, const int block,
 #endif  // CONFIG_FORWARDSKIP
                         const TX_SIZE tx_size, const TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                        const CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
                         const TXB_CTX *const txb_ctx, int reduced_tx_set_used) {
   const struct macroblock_plane *p = &x->plane[plane];
   const int eob = p->eobs[block];
@@ -1307,22 +1404,30 @@
   const PLANE_TYPE plane_type = get_plane_type(plane);
   const LV_MAP_COEFF_COST *const coeff_costs =
       &x->coeff_costs.coeff_costs[txs_ctx][plane_type];
+  const MACROBLOCKD *const xd = &x->e_mbd;
   if (eob == 0) {
+#if CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
+    if (plane == AOM_PLANE_V && !keep_chroma_c2(cctx_type)) return 0;
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
 #if CONFIG_CONTEXT_DERIVATION
     int txb_skip_ctx = txb_ctx->txb_skip_ctx;
+    int skip_cost = 0;
     if (plane == AOM_PLANE_Y || plane == AOM_PLANE_U) {
-      return coeff_costs->txb_skip_cost[txb_skip_ctx][1];
+      skip_cost += coeff_costs->txb_skip_cost[txb_skip_ctx][1];
     } else {
       txb_skip_ctx +=
           (x->plane[AOM_PLANE_U].eobs[block] ? V_TXB_SKIP_CONTEXT_OFFSET : 0);
-      return coeff_costs->v_txb_skip_cost[txb_skip_ctx][1];
+      skip_cost += coeff_costs->v_txb_skip_cost[txb_skip_ctx][1];
     }
 #else
-    return coeff_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][1];
+    skip_cost += coeff_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][1];
 #endif  // CONFIG_CONTEXT_DERIVATION
+#if CONFIG_CROSS_CHROMA_TX
+    skip_cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+#endif  // CONFIG_CROSS_CHROMA_TX
+    return skip_cost;
   }
 
-  const MACROBLOCKD *const xd = &x->e_mbd;
 #if CONFIG_IST
   const TX_CLASS tx_class = tx_type_to_class[get_primary_tx_type(tx_type)];
 #else
@@ -1340,15 +1445,24 @@
       use_inter_fsc(cm, plane, tx_type, is_inter_block(mbmi, xd->tree_type))) {
     return warehouse_efficients_txb_skip(x, plane, block, tx_size, txb_ctx, p,
                                          eob, coeff_costs, xd, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                                         cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
                                          reduced_tx_set_used);
   } else {
     return warehouse_efficients_txb(x, plane, block, tx_size, txb_ctx, p, eob,
                                     plane_type, coeff_costs, xd, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                                    cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
                                     tx_class, reduced_tx_set_used);
   }
 #else
   return warehouse_efficients_txb(x, plane, block, tx_size, txb_ctx, p, eob,
                                   plane_type, coeff_costs, xd, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                                  cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
                                   tx_class, reduced_tx_set_used);
 #endif  // CONFIG_FORWARDSKIP
 }
@@ -1358,8 +1472,12 @@
     const AV1_COMMON *cm,
 #endif  // CONFIG_FORWARDSKIP
     const MACROBLOCK *x, const int plane, const int block,
-    const TX_SIZE tx_size, const TX_TYPE tx_type, const TXB_CTX *const txb_ctx,
-    const int reduced_tx_set_used, const int adjust_eob) {
+    const TX_SIZE tx_size, const TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+    const CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+    const TXB_CTX *const txb_ctx, const int reduced_tx_set_used,
+    const int adjust_eob) {
   const struct macroblock_plane *p = &x->plane[plane];
   int eob = p->eobs[block];
 
@@ -1378,22 +1496,27 @@
   const PLANE_TYPE plane_type = get_plane_type(plane);
   const LV_MAP_COEFF_COST *const coeff_costs =
       &x->coeff_costs.coeff_costs[txs_ctx][plane_type];
+  const MACROBLOCKD *const xd = &x->e_mbd;
   if (eob == 0) {
 #if CONFIG_CONTEXT_DERIVATION
     int txb_skip_ctx = txb_ctx->txb_skip_ctx;
+    int skip_cost = 0;
     if (plane == AOM_PLANE_Y || plane == AOM_PLANE_U) {
-      return coeff_costs->txb_skip_cost[txb_skip_ctx][1];
+      skip_cost += coeff_costs->txb_skip_cost[txb_skip_ctx][1];
     } else {
       txb_skip_ctx +=
           (x->plane[AOM_PLANE_U].eobs[block] ? V_TXB_SKIP_CONTEXT_OFFSET : 0);
-      return coeff_costs->v_txb_skip_cost[txb_skip_ctx][1];
+      skip_cost += coeff_costs->v_txb_skip_cost[txb_skip_ctx][1];
     }
 #else
-    return coeff_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][1];
+    skip_cost += coeff_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][1];
 #endif  // CONFIG_CONTEXT_DERIVATION
+#if CONFIG_CROSS_CHROMA_TX
+    skip_cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+#endif  // CONFIG_CROSS_CHROMA_TX
+    return skip_cost;
   }
 
-  const MACROBLOCKD *const xd = &x->e_mbd;
 #if CONFIG_IST
   const TX_CLASS tx_class = tx_type_to_class[get_primary_tx_type(tx_type)];
 #else
@@ -1405,7 +1528,11 @@
       cm,
 #endif  // CONFIG_FORWARDSKIP
       x, plane, block, tx_size, txb_ctx, eob, plane_type, coeff_costs, xd,
-      tx_type, tx_class, reduced_tx_set_used);
+      tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+      cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+      tx_class, reduced_tx_set_used);
 }
 
 static AOM_FORCE_INLINE int get_two_coeff_cost_simple(
@@ -1823,6 +1950,9 @@
 
 int av1_optimize_txb_new(const struct AV1_COMP *cpi, MACROBLOCK *x, int plane,
                          int block, TX_SIZE tx_size, TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                         CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
                          const TXB_CTX *const txb_ctx, int *rate_cost,
                          int sharpness) {
   MACROBLOCKD *xd = &x->e_mbd;
@@ -2033,6 +2163,10 @@
   p->txb_entropy_ctx[block] =
       av1_get_txb_entropy_context(qcoeff, scan_order, p->eobs[block]);
 
+#if CONFIG_CROSS_CHROMA_TX
+  accu_rate += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+#endif  // CONFIG_CROSS_CHROMA_TX
+
   *rate_cost = accu_rate;
   return eob;
 }
@@ -2055,6 +2189,32 @@
   return (uint8_t)cul_level;
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+static void update_cctx_type_count(const AV1_COMMON *cm, MACROBLOCKD *xd,
+                                   int blk_row, int blk_col, TX_SIZE tx_size,
+                                   FRAME_COUNTS *counts,
+                                   uint8_t allow_update_cdf) {
+  const MB_MODE_INFO *mbmi = xd->mi[0];
+  const int is_inter = is_inter_block(mbmi, xd->tree_type);
+  FRAME_CONTEXT *fc = xd->tile_ctx;
+#if !CONFIG_ENTROPY_STATS
+  (void)counts;
+#endif  // !CONFIG_ENTROPY_STATS
+  if (((is_inter && CCTX_INTER) || (!is_inter && CCTX_INTRA)) &&
+      cm->quant_params.base_qindex > 0 &&
+      !mbmi->skip_txfm[xd->tree_type == CHROMA_PART] &&
+      !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+    const CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+    if (allow_update_cdf)
+      update_cdf(fc->cctx_type_cdf[txsize_sqr_map[tx_size]], cctx_type,
+                 CCTX_TYPES);
+#if CONFIG_ENTROPY_STATS
+    ++counts->cctx_type[txsize_sqr_map[tx_size]][cctx_type];
+#endif  // CONFIG_ENTROPY_STATS
+  }
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 static void update_tx_type_count(const AV1_COMP *cpi, const AV1_COMMON *cm,
                                  MACROBLOCKD *xd, int blk_row, int blk_col,
                                  int plane, TX_SIZE tx_size,
@@ -2416,6 +2576,7 @@
     return;
   }
 #endif  // CONFIG_FORWARDSKIP
+
   const SCAN_ORDER *const scan_order = get_scan(tx_size, tx_type);
   tran_low_t *tcoeff;
   assert(args->dry_run != DRY_RUN_COSTCOEFFS);
@@ -2430,6 +2591,28 @@
                 0
 #endif  // CONFIG_FORWARDSKIP
     );
+#if CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
+    if (plane == AOM_PLANE_V) {
+      CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+      if (!keep_chroma_c2(cctx_type)) {
+        assert(eob == 0);
+        CB_COEFF_BUFFER *cb_coef_buff = x->cb_coef_buff;
+        const int txb_offset =
+            x->mbmi_ext_frame
+                ->cb_offset[(plane > 0 && xd->tree_type == CHROMA_PART) ? 1
+                                                                        : 0] /
+            (TX_SIZE_W_MIN * TX_SIZE_H_MIN);
+        uint16_t *eob_txb = cb_coef_buff->eobs[plane] + txb_offset;
+        uint8_t *const entropy_ctx =
+            cb_coef_buff->entropy_ctx[plane] + txb_offset;
+        entropy_ctx[block] = txb_ctx.txb_skip_ctx;
+        eob_txb[block] = 0;
+        av1_set_entropy_contexts(xd, pd, plane, plane_bsize, tx_size, 0,
+                                 blk_col, blk_row);
+        return;
+      }
+    }
+#endif
     const int bwl = get_txb_bwl(tx_size);
     const int width = get_txb_wide(tx_size);
     const int height = get_txb_high(tx_size);
@@ -2475,6 +2658,19 @@
     entropy_ctx[block] = txb_ctx.txb_skip_ctx;
     eob_txb[block] = eob;
 
+#if CONFIG_CROSS_CHROMA_TX
+#if CCTX_C1_NONZERO
+    if (plane == AOM_PLANE_U && eob > 0)
+      update_cctx_type_count(cm, xd, blk_row, blk_col, tx_size, td->counts,
+                             allow_update_cdf);
+#else
+    if (plane == AOM_PLANE_V &&
+        (eob > 0 || x->plane[AOM_PLANE_U].eobs[block] > 0)) {
+      update_cctx_type_count(cm, xd, blk_row, blk_col, tx_size, td->counts,
+                             allow_update_cdf);
+    }
+#endif  // CCTX_C1_NONZERO
+#endif  // CONFIG_CROSS_CHROMA_TX
     if (eob == 0) {
       av1_set_entropy_contexts(xd, pd, plane, plane_bsize, tx_size, 0, blk_col,
                                blk_row);
diff --git a/av1/encoder/encodetxb.h b/av1/encoder/encodetxb.h
index fbd1b49..7e1be18 100644
--- a/av1/encoder/encodetxb.h
+++ b/av1/encoder/encodetxb.h
@@ -78,13 +78,16 @@
  * \param[in]    cpi            Top-level encoder structure
  */
 void av1_free_txb_buf(AV1_COMP *cpi);
+
 /*!\brief Compute the entropy cost of coding coefficients in a transform block.
  *
  * \ingroup coefficient_coding
- *
- * \param[in]    cm                   Top-level structure shared by encoder and
- * decoder
- * \param[in]    x                    Pointer to structure holding the data for
+ */
+#if CONFIG_FORWARDSKIP
+/* \param[in]    cm                   Top-level structure shared by encoder and
+ * decoder*/
+#endif  // CONFIG_FORWARDSKIP
+/* \param[in]    x                    Pointer to structure holding the data for
  the current encoding macroblock.
  * \param[in]    plane                The index of the current plane.
  * \param[in]    block                The index of the current transform block
@@ -92,8 +95,12 @@
  * macroblock. It's defined by number of 4x4 units that have been coded before
  * the currernt transform block.
  * \param[in]    tx_size              The transform size.
- * \param[in]    tx_type              The transform type.
- * \param[in]    txb_ctx              Context info for entropy coding transform
+ * \param[in]    tx_type              The transform type.*/
+#if CONFIG_CROSS_CHROMA_TX
+/* \param[in]    cctx_type            The cross chroma component transform
+ * type*/
+#endif  // CONFIG_CROSS_CHROMA_TX
+/* \param[in]    txb_ctx              Context info for entropy coding transform
  block
  * skip flag (tx_skip) and the sign of DC coefficient (dc_sign).
  * \param[in]    reduced_tx_set_used  Whether the transform type is chosen from
@@ -104,8 +111,11 @@
     const AV1_COMMON *cm,
 #endif  // CONFIG_FORWARDSKIP
     const MACROBLOCK *x, const int plane, const int block,
-    const TX_SIZE tx_size, const TX_TYPE tx_type, const TXB_CTX *const txb_ctx,
-    int reduced_tx_set_used);
+    const TX_SIZE tx_size, const TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+    const CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+    const TXB_CTX *const txb_ctx, int reduced_tx_set_used);
 
 /*!\brief Estimate the entropy cost of coding a transform block using Laplacian
  * distribution.
@@ -123,18 +133,23 @@
  *
  * Compared to \ref av1_cost_coeffs_txb, this function is much faster but less
  * accurate.
- *
- * \param[in]    cm             Top-level structure shared by encoder and
- * decoder
- * \param[in]    x              Pointer to structure holding the data for the
+ */
+#if CONFIG_FORWARDSKIP
+/* \param[in]    cm             Top-level structure shared by encoder and
+ * decoder*/
+#endif  // CONFIG_FORWARDSKIP
+/* \param[in]    x              Pointer to structure holding the data for the
                                 current encoding macroblock
  * \param[in]    plane          The index of the current plane
  * \param[in]    block          The index of the current transform block in the
  * macroblock. It's defined by number of 4x4 units that have been coded before
  * the currernt transform block
  * \param[in]    tx_size        The transform size
- * \param[in]    tx_type        The transform type
- * \param[in]    txb_ctx        Context info for entropy coding transform block
+ * \param[in]    tx_type        The transform type*/
+#if CONFIG_CROSS_CHROMA_TX
+/* \param[in]    cctx_type      The cross chroma component transform type*/
+#endif  // CONFIG_CROSS_CHROMA_TX
+/* \param[in]    txb_ctx        Context info for entropy coding transform block
  * skip flag (tx_skip) and the sign of DC coefficient (dc_sign).
  * \param[in]    reduced_tx_set_used  Whether the transform type is chosen from
  * a reduced set.
@@ -149,8 +164,12 @@
     const AV1_COMMON *cm,
 #endif  // CONFIG_FORWARDSKIP
     const MACROBLOCK *x, const int plane, const int block,
-    const TX_SIZE tx_size, const TX_TYPE tx_type, const TXB_CTX *const txb_ctx,
-    const int reduced_tx_set_used, const int adjust_eob);
+    const TX_SIZE tx_size, const TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+    const CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+    const TXB_CTX *const txb_ctx, const int reduced_tx_set_used,
+    const int adjust_eob);
 
 /*!\brief Estimate the entropy cost of transform coefficients using Laplacian
  * distribution.
@@ -465,6 +484,39 @@
                                             TX_SIZE tx_size, void *arg);
 #endif  // CONFIG_FORWARDSKIP
 
+#if CONFIG_CROSS_CHROMA_TX
+/*!\brief Adjust the magnitude of quantized coefficients to achieve better
+ * rate-distortion (RD) trade-off.
+ *
+ * \ingroup coefficient_coding
+ *
+ * This function goes through each coefficient and greedily choose to lower
+ * the coefficient magnitude by 1 or not based on the RD score.
+ *
+ * The coefficients are processing in reversed scan order.
+ *
+ * Note that, the end of block position (eob) may change if the original last
+ * coefficient is lowered to zero.
+ *
+ * \param[in]    cpi            Top-level encoder structure
+ * \param[in]    x              Pointer to structure holding the data for the
+                                current encoding macroblock
+ * \param[in]    plane          The index of the current plane
+ * \param[in]    block          The index of the current transform block in the
+ * \param[in]    tx_size        The transform size
+ * \param[in]    tx_type        The transform type
+ * \param[in]    cctx_type      The cross chroma component transform type
+ * \param[in]    txb_ctx        Context info for entropy coding transform block
+ * skip flag (tx_skip) and the sign of DC coefficient (dc_sign).
+ * \param[out]   rate_cost      The entropy cost of coding the transform block
+ * after adjustment of coefficients.
+ * \param[in]    sharpness      When sharpness == 1, the function will be less
+ * aggressive toward lowering the magnitude of coefficients.
+ * In this way, the transform block will contain more high-frequency
+ coefficients
+ * and therefore preserve the sharpness of the reconstructed block.
+ */
+#else
 /*!\brief Adjust the magnitude of quantized coefficients to achieve better
  * rate-distortion (RD) trade-off.
  *
@@ -495,8 +547,12 @@
  coefficients
  * and therefore preserve the sharpness of the reconstructed block.
  */
+#endif  // CONFIG_CROSS_CHROMA_TX
 int av1_optimize_txb_new(const struct AV1_COMP *cpi, MACROBLOCK *x, int plane,
                          int block, TX_SIZE tx_size, TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                         CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
                          const TXB_CTX *const txb_ctx, int *rate_cost,
                          int sharpness);
 
@@ -519,6 +575,10 @@
  */
 CB_COEFF_BUFFER *av1_get_cb_coeff_buffer(const struct AV1_COMP *cpi, int mi_row,
                                          int mi_col);
+#if CONFIG_CROSS_CHROMA_TX
+int get_cctx_type_cost(const MACROBLOCK *x, const MACROBLOCKD *xd, int plane,
+                       TX_SIZE tx_size, int block, CctxType cctx_type);
+#endif  // CONFIG_CROSS_CHROMA_TX
 
 #if CONFIG_CONTEXT_DERIVATION
 /*!\brief Returns the entropy cost associated with skipping the current
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index 001f672..c04804b 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -501,7 +501,8 @@
 
 #if CONFIG_CROSS_CHROMA_TX
 void av1_fwd_cross_chroma_tx_block(tran_low_t *coeff_u, tran_low_t *coeff_v,
-                                   TX_SIZE tx_size) {
+                                   TX_SIZE tx_size, CctxType cctx_type) {
+  if (cctx_type == CCTX_NONE) return;
 #if CCTX_DC_ONLY
   const int ncoeffs = 1;
 #else
@@ -511,9 +512,12 @@
   int32_t *src_v = (int32_t *)coeff_v;
   int32_t tmp[2] = { 0, 0 };
 
+  const int angle_idx = cctx_type - CCTX_START;
   for (int i = 0; i < ncoeffs; i++) {
-    tmp[0] = cctx_mtx[0] * src_u[i] + cctx_mtx[1] * src_v[i];
-    tmp[1] = cctx_mtx[2] * src_u[i] + cctx_mtx[3] * src_v[i];
+    tmp[0] =
+        cctx_mtx[angle_idx][0] * src_u[i] + cctx_mtx[angle_idx][1] * src_v[i];
+    tmp[1] =
+        -cctx_mtx[angle_idx][1] * src_u[i] + cctx_mtx[angle_idx][0] * src_v[i];
     src_u[i] = ROUND_POWER_OF_TWO_SIGNED(tmp[0], CCTX_PREC_BITS);
     src_v[i] = ROUND_POWER_OF_TWO_SIGNED(tmp[1], CCTX_PREC_BITS);
   }
diff --git a/av1/encoder/hybrid_fwd_txfm.h b/av1/encoder/hybrid_fwd_txfm.h
index 6bc7ffb..ca086b8 100644
--- a/av1/encoder/hybrid_fwd_txfm.h
+++ b/av1/encoder/hybrid_fwd_txfm.h
@@ -27,7 +27,7 @@
 
 #if CONFIG_CROSS_CHROMA_TX
 void av1_fwd_cross_chroma_tx_block(tran_low_t *dqcoeff_u, tran_low_t *dqcoeff_v,
-                                   TX_SIZE tx_size);
+                                   TX_SIZE tx_size, CctxType cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
 
 #if CONFIG_IST
diff --git a/av1/encoder/intra_mode_search.c b/av1/encoder/intra_mode_search.c
index 2583578..dfb4af5 100644
--- a/av1/encoder/intra_mode_search.c
+++ b/av1/encoder/intra_mode_search.c
@@ -346,9 +346,9 @@
 #endif
   int64_t best_rd_uv[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
   int best_c[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
-#if CONFIG_DEBUG
+#if CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
   int best_rate_uv[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
-#endif  // CONFIG_DEBUG
+#endif  // CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
 
 #if CONFIG_CONTEXT_DERIVATION
   const int skip_trellis = 0;
@@ -375,9 +375,9 @@
     const int alpha_rate = mode_costs->cfl_cost[joint_sign][CFL_PRED_U][0];
     best_rd_uv[joint_sign][CFL_PRED_U] =
         RDCOST(x->rdmult, rd_stats.rate + alpha_rate, rd_stats.dist);
-#if CONFIG_DEBUG
+#if CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
     best_rate_uv[joint_sign][CFL_PRED_U] = rd_stats.rate;
-#endif  // CONFIG_DEBUG
+#endif  // CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
   }
   // Collect RD stats for alpha values other than zero in CFL_PRED_U.
   for (int pn_sign = CFL_SIGN_NEG; pn_sign < CFL_SIGNS; pn_sign++) {
@@ -400,9 +400,9 @@
         if (this_rd >= best_rd_uv[joint_sign][CFL_PRED_U]) continue;
         best_rd_uv[joint_sign][CFL_PRED_U] = this_rd;
         best_c[joint_sign][CFL_PRED_U] = c;
-#if CONFIG_DEBUG
+#if CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
         best_rate_uv[joint_sign][CFL_PRED_U] = rd_stats.rate;
-#endif  // CONFIG_DEBUG
+#endif  // CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
         flag = 2;
         if (best_rd_uv[joint_sign][CFL_PRED_V] == INT64_MAX) continue;
         this_rd += mode_rd + best_rd_uv[joint_sign][CFL_PRED_V];
@@ -439,9 +439,9 @@
       if (this_rd >= best_rd_uv[joint_sign][CFL_PRED_V]) continue;
       best_rd_uv[joint_sign][CFL_PRED_V] = this_rd;
       best_c[joint_sign][CFL_PRED_V] = c;
-#if CONFIG_DEBUG
+#if CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
       best_rate_uv[joint_sign][CFL_PRED_V] = rd_stats.rate;
-#endif  // CONFIG_DEBUG
+#endif  // CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
       flag = 2;
       if (best_rd_uv[joint_sign][CFL_PRED_U] == INT64_MAX) continue;
       this_rd += mode_rd + best_rd_uv[joint_sign][CFL_PRED_U];
@@ -475,9 +475,9 @@
       const int alpha_rate = mode_costs->cfl_cost[joint_sign][plane][0];
       best_rd_uv[joint_sign][plane] =
           RDCOST(x->rdmult, rd_stats.rate + alpha_rate, rd_stats.dist);
-#if CONFIG_DEBUG
+#if CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
       best_rate_uv[joint_sign][plane] = rd_stats.rate;
-#endif  // CONFIG_DEBUG
+#endif  // CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
     }
   }
 
@@ -506,9 +506,9 @@
           if (this_rd >= best_rd_uv[joint_sign][plane]) continue;
           best_rd_uv[joint_sign][plane] = this_rd;
           best_c[joint_sign][plane] = c;
-#if CONFIG_DEBUG
+#if CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
           best_rate_uv[joint_sign][plane] = rd_stats.rate;
-#endif  // CONFIG_DEBUG
+#endif  // CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
           flag = 2;
           if (best_rd_uv[joint_sign][!plane] == INT64_MAX) continue;
           this_rd += mode_rd + best_rd_uv[joint_sign][!plane];
@@ -530,7 +530,7 @@
     ind = (u << CFL_ALPHABET_SIZE_LOG2) + v;
     best_rate_overhead = mode_costs->cfl_cost[best_joint_sign][CFL_PRED_U][u] +
                          mode_costs->cfl_cost[best_joint_sign][CFL_PRED_V][v];
-#if CONFIG_DEBUG
+#if CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
     xd->cfl.rate =
 #if CONFIG_AIMC
         mode_costs->intra_uv_mode_cost[CFL_ALLOWED][uv_context][UV_CFL_PRED] +
@@ -539,7 +539,7 @@
 #endif
         best_rate_overhead + best_rate_uv[best_joint_sign][CFL_PRED_U] +
         best_rate_uv[best_joint_sign][CFL_PRED_V];
-#endif  // CONFIG_DEBUG
+#endif  // CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
   } else {
     best_joint_sign = 0;
   }
@@ -563,6 +563,9 @@
 int64_t av1_rd_pick_intra_sbuv_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
                                     int *rate, int *rate_tokenonly,
                                     int64_t *distortion, int *skippable,
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+                                    const PICK_MODE_CONTEXT *ctx,
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
                                     BLOCK_SIZE bsize, TX_SIZE max_tx_size) {
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *xd = &x->e_mbd;
@@ -594,6 +597,9 @@
       // this function everytime we search through uv modes. There is some
       // potential speed up here if we cache the result to avoid redundant
       // computation.
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+      // TODO(kslu) fix CFL and apply the pipeline change
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
       av1_encode_intra_block_plane(cpi, x, mbmi->sb_type[PLANE_TYPE_Y],
                                    AOM_PLANE_Y, DRY_RUN_NORMAL,
                                    cpi->optimize_seg_arr[mbmi->segment_id]);
@@ -669,15 +675,18 @@
                 intra_mode_info_cost_uv(cpi, x, mbmi, bsize, mode_cost);
     if (mode == UV_CFL_PRED) {
       assert(is_cfl_allowed(xd) && intra_mode_cfg->enable_cfl_intra);
-#if CONFIG_DEBUG
+#if CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
       if (!xd->lossless[mbmi->segment_id])
         assert(xd->cfl.rate == tokenonly_rd_stats.rate + mode_cost);
-#endif  // CONFIG_DEBUG
+#endif  // CONFIG_DEBUG && !(CONFIG_CROSS_CHROMA_TX && CCTX_INTRA)
     }
     this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
 
     if (this_rd < best_rd) {
       best_mbmi = *mbmi;
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+      av1_copy_array(ctx->cctx_type_map, xd->cctx_type_map, ctx->num_4x4_blk);
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
       best_rd = this_rd;
       *rate = this_rate;
       *rate_tokenonly = tokenonly_rd_stats.rate;
@@ -710,6 +719,9 @@
   }
 
   *mbmi = best_mbmi;
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+  av1_copy_array(xd->cctx_type_map, ctx->cctx_type_map, ctx->num_4x4_blk);
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
   // Make sure we actually chose a mode
   assert(best_rd < INT64_MAX);
   return best_rd;
@@ -780,7 +792,11 @@
       av1_rd_pick_intra_sbuv_mode(cpi, x, &intra_search_state->rate_uv_intra,
                                   &intra_search_state->rate_uv_tokenonly,
                                   &intra_search_state->dist_uvs,
-                                  &intra_search_state->skip_uvs, bsize, uv_tx);
+                                  &intra_search_state->skip_uvs,
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+                                  ctx,
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+                                  bsize, uv_tx);
       intra_search_state->mode_uv = mbmi->uv_mode;
       intra_search_state->pmi_uv = *pmi;
       intra_search_state->uv_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];
@@ -1091,7 +1107,11 @@
       av1_rd_pick_intra_sbuv_mode(cpi, x, &intra_search_state->rate_uv_intra,
                                   &intra_search_state->rate_uv_tokenonly,
                                   &intra_search_state->dist_uvs,
-                                  &intra_search_state->skip_uvs, bsize, uv_tx);
+                                  &intra_search_state->skip_uvs,
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+                                  ctx,
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+                                  bsize, uv_tx);
       intra_search_state->mode_uv = mbmi->uv_mode;
       if (try_palette) intra_search_state->pmi_uv = *pmi;
       intra_search_state->uv_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];
diff --git a/av1/encoder/intra_mode_search.h b/av1/encoder/intra_mode_search.h
index a2907d2..f058eb0 100644
--- a/av1/encoder/intra_mode_search.h
+++ b/av1/encoder/intra_mode_search.h
@@ -281,6 +281,39 @@
                                    BLOCK_SIZE bsize, int64_t best_rd,
                                    PICK_MODE_CONTEXT *ctx);
 
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+/*!\brief Perform intra-mode search on chroma channels.
+ *
+ * \ingroup intra_mode_search
+ * \callergraph
+ * \callgraph
+ * This function performs intra-mode search on the chroma channels. Just like
+ * \ref av1_rd_pick_intra_sby_mode(), this function searches over palette mode
+ * (filter_intra is not available on chroma planes). Unlike \ref
+ * av1_rd_pick_intra_sby_mode() this function is used by both inter and intra
+ * frames.
+ *
+ * \param[in]    cpi                Top-level encoder structure.
+ * \param[in]    x                  Pointer to structure holding all the data
+ *                                  for the current macroblock.
+ * \param[in]    rate               The total rate needed to predict the current
+ *                                  chroma block.
+ * \param[in]    rate_tokenonly     The rate without the cost of sending the
+ *                                  prediction modes.
+ *                                  chroma block.
+ *                                  after the reconstruction.
+ * \param[in]    distortion         The chroma distortion of the best prediction
+ *                                  after the reconstruction.
+ * \param[in]    skippable          Whether we can skip txfm process.
+ * \param[in]    ctx                Structure to hold the number of 4x4 blks to
+ *                                  copy the tx_type and txfm_skip arrays.
+ * \param[in]    bsize              Current partition block size.
+ * \param[in]    max_tx_size        The maximum tx_size available
+ *
+ * \return Returns the rd_cost of the best uv mode found. This also updates the
+ * mbmi, the rate and distortion, distortion.
+ */
+#else
 /*!\brief Perform intra-mode search on chroma channels.
  *
  * \ingroup intra_mode_search
@@ -310,9 +343,13 @@
  * \return Returns the rd_cost of the best uv mode found. This also updates the
  * mbmi, the rate and distortion, distortion.
  */
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
 int64_t av1_rd_pick_intra_sbuv_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
                                     int *rate, int *rate_tokenonly,
                                     int64_t *distortion, int *skippable,
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+                                    const PICK_MODE_CONTEXT *ctx,
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
                                     BLOCK_SIZE bsize, TX_SIZE max_tx_size);
 
 /*! \brief Return the number of colors in src. Used by palette mode.
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index 5cdd990..d102d55 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -354,8 +354,17 @@
     }
     mbmi->skip_txfm[xd->tree_type == CHROMA_PART] = 1;
     for (int plane = plane_start; plane < plane_end; ++plane) {
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+      if (plane == AOM_PLANE_Y)
+        av1_encode_intra_block_plane(cpi, x, bsize, plane, dry_run,
+                                     cpi->optimize_seg_arr[mbmi->segment_id]);
+      else if (plane == AOM_PLANE_U)
+        av1_encode_intra_block_joint_uv(
+            cpi, x, bsize, dry_run, cpi->optimize_seg_arr[mbmi->segment_id]);
+#else
       av1_encode_intra_block_plane(cpi, x, bsize, plane, dry_run,
                                    cpi->optimize_seg_arr[mbmi->segment_id]);
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
     }
 
     // If there is at least one lossless segment, force the skip for intra
@@ -744,6 +753,9 @@
 
   // Sets up the tx_type_map buffer in MACROBLOCKD.
   xd->tx_type_map = txfm_info->tx_type_map_;
+#if CONFIG_CROSS_CHROMA_TX
+  xd->cctx_type_map = txfm_info->cctx_type_map_;
+#endif  // CONFIG_CROSS_CHROMA_TX
   xd->tx_type_map_stride = mi_size_wide[bsize];
 
   for (i = 0; i < num_planes; ++i) {
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 73684e0..9eaec85 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -320,6 +320,13 @@
   }
 #endif  // CONFIG_IST
 
+#if CONFIG_CROSS_CHROMA_TX
+  for (i = 0; i < EXT_TX_SIZES; ++i) {
+    av1_cost_tokens_from_cdf(mode_costs->cctx_type_cost[i],
+                             fc->cctx_type_cdf[i], NULL);
+  }
+#endif  // CONFIG_CROSS_CHROMA_TX
+
   if (!frame_is_intra_only(cm)) {
     for (i = 0; i < COMP_INTER_CONTEXTS; ++i) {
       av1_cost_tokens_from_cdf(mode_costs->comp_inter_cost[i],
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 848c589..7baa75e 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1625,6 +1625,9 @@
   RD_STATS best_rd_stats, best_rd_stats_y, best_rd_stats_uv;
   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
   TX_TYPE best_tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+#if CONFIG_CROSS_CHROMA_TX
+  TX_TYPE best_cctx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+#endif  // CONFIG_CROSS_CHROMA_TX
   const int rate_mv0 = *rate_mv;
   const int interintra_allowed =
       cm->seq_params.enable_interintra_compound && is_interintra_allowed(mbmi);
@@ -1940,6 +1943,10 @@
       memcpy(best_blk_skip, txfm_info->blk_skip,
              sizeof(txfm_info->blk_skip[0]) * xd->height * xd->width);
       av1_copy_array(best_tx_type_map, xd->tx_type_map, xd->height * xd->width);
+#if CONFIG_CROSS_CHROMA_TX
+      av1_copy_array(best_cctx_type_map, xd->cctx_type_map,
+                     xd->height * xd->width);
+#endif  // CONFIG_CROSS_CHROMA_TX
       best_xskip_txfm = mbmi->skip_txfm[xd->tree_type == CHROMA_PART];
     }
   }
@@ -1958,6 +1965,9 @@
   memcpy(txfm_info->blk_skip, best_blk_skip,
          sizeof(txfm_info->blk_skip[0]) * xd->height * xd->width);
   av1_copy_array(xd->tx_type_map, best_tx_type_map, xd->height * xd->width);
+#if CONFIG_CROSS_CHROMA_TX
+  av1_copy_array(xd->cctx_type_map, best_cctx_type_map, xd->height * xd->width);
+#endif  // CONFIG_CROSS_CHROMA_TX
   txfm_info->skip_txfm = best_xskip_txfm;
 
   restore_dst_buf(xd, *orig_dst, num_planes);
@@ -3043,6 +3053,9 @@
   int64_t best_rd = INT64_MAX;
   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
   TX_TYPE best_tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+#if CONFIG_CROSS_CHROMA_TX
+  CctxType best_cctx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+#endif  // CONFIG_CROSS_CHROMA_TX
   MB_MODE_INFO best_mbmi = *mbmi;
   int best_xskip_txfm = 0;
   int64_t newmv_ret_val = INT64_MAX;
@@ -3315,6 +3328,10 @@
                  sizeof(best_blk_skip[0]) * xd->height * xd->width);
           av1_copy_array(best_tx_type_map, xd->tx_type_map,
                          xd->height * xd->width);
+#if CONFIG_CROSS_CHROMA_TX
+          av1_copy_array(best_cctx_type_map, xd->cctx_type_map,
+                         xd->height * xd->width);
+#endif  // CONFIG_CROSS_CHROMA_TX
           motion_mode_cand->rate_mv = rate_mv;
           motion_mode_cand->rate2_nocoeff = rate2_nocoeff;
         }
@@ -3343,6 +3360,9 @@
   memcpy(txfm_info->blk_skip, best_blk_skip,
          sizeof(best_blk_skip[0]) * xd->height * xd->width);
   av1_copy_array(xd->tx_type_map, best_tx_type_map, xd->height * xd->width);
+#if CONFIG_CROSS_CHROMA_TX
+  av1_copy_array(xd->cctx_type_map, best_cctx_type_map, xd->height * xd->width);
+#endif  // CONFIG_CROSS_CHROMA_TX
 
   rd_stats->rdcost = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
   assert(av1_check_newmv_joint_nonzero(cm, x));
@@ -3548,6 +3568,10 @@
   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE] = { 0 };
   TX_TYPE best_tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
   av1_copy_array(best_tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
+#if CONFIG_CROSS_CHROMA_TX
+  TX_TYPE best_cctx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+  av1_copy_array(best_cctx_type_map, xd->cctx_type_map, ctx->num_4x4_blk);
+#endif  // CONFIG_CROSS_CHROMA_TX
 
   FULLPEL_MOTION_SEARCH_PARAMS fullms_params;
   const search_site_config *lookahead_search_sites =
@@ -3840,6 +3864,10 @@
       memcpy(best_blk_skip, txfm_info->blk_skip,
              sizeof(txfm_info->blk_skip[0]) * xd->height * xd->width);
       av1_copy_array(best_tx_type_map, xd->tx_type_map, xd->height * xd->width);
+#if CONFIG_CROSS_CHROMA_TX
+      av1_copy_array(best_cctx_type_map, xd->cctx_type_map,
+                     xd->height * xd->width);
+#endif  // CONFIG_CROSS_CHROMA_TX
     }
   }
   *mbmi = best_mbmi;
@@ -3855,6 +3883,9 @@
   memcpy(txfm_info->blk_skip, best_blk_skip,
          sizeof(txfm_info->blk_skip[0]) * xd->height * xd->width);
   av1_copy_array(xd->tx_type_map, best_tx_type_map, ctx->num_4x4_blk);
+#if CONFIG_CROSS_CHROMA_TX
+  av1_copy_array(xd->cctx_type_map, best_cctx_type_map, ctx->num_4x4_blk);
+#endif  // CONFIG_CROSS_CHROMA_TX
 #if CONFIG_RD_DEBUG
   mbmi->rd_stats = *rd_stats;
 #endif
@@ -3909,8 +3940,12 @@
       }
       const TX_SIZE max_uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
       av1_rd_pick_intra_sbuv_mode(cpi, x, &rate_uv, &rate_uv_tokenonly,
-                                  &dist_uv, &uv_skip_txfm, bsize,
-                                  max_uv_tx_size);
+                                  &dist_uv, &uv_skip_txfm,
+
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+                                  ctx,
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+                                  bsize, max_uv_tx_size);
     }
 
     // Intra block is always coded as non-skip
@@ -4178,6 +4213,9 @@
         memcpy(ctx->blk_skip, txfm_info->blk_skip,
                sizeof(txfm_info->blk_skip[0]) * ctx->num_4x4_blk);
         av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
+#if CONFIG_CROSS_CHROMA_TX
+        av1_copy_array(ctx->cctx_type_map, xd->cctx_type_map, ctx->num_4x4_blk);
+#endif  // CONFIG_CROSS_CHROMA_TX
         search_state->best_mode_skippable = 0;
         search_state->best_skip2 = 0;
         search_state->best_rate_y =
@@ -4551,6 +4589,9 @@
         *best_mbmode = *mbmi;
         av1_copy_array(ctx->blk_skip, txfm_info->blk_skip, ctx->num_4x4_blk);
         av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
+#if CONFIG_CROSS_CHROMA_TX
+        av1_copy_array(ctx->cctx_type_map, xd->cctx_type_map, ctx->num_4x4_blk);
+#endif  // CONFIG_CROSS_CHROMA_TX
         rd_cost->rate = this_rate;
         rd_cost->dist = rd_stats_y.dist + rd_stats_uv.dist;
         rd_cost->sse = rd_stats_y.sse + rd_stats_uv.sse;
@@ -5737,6 +5778,9 @@
   memcpy(ctx->blk_skip, txfm_info->blk_skip,
          sizeof(txfm_info->blk_skip[0]) * ctx->num_4x4_blk);
   av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
+#if CONFIG_CROSS_CHROMA_TX
+  av1_copy_array(ctx->cctx_type_map, xd->cctx_type_map, ctx->num_4x4_blk);
+#endif  // CONFIG_CROSS_CHROMA_TX
 }
 
 // Find the best RD for a reference frame (among single reference modes)
@@ -6876,6 +6920,9 @@
       memcpy(ctx->blk_skip, txfm_info->blk_skip,
              sizeof(txfm_info->blk_skip[0]) * ctx->num_4x4_blk);
       av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
+#if CONFIG_CROSS_CHROMA_TX
+      av1_copy_array(ctx->cctx_type_map, xd->cctx_type_map, ctx->num_4x4_blk);
+#endif  // CONFIG_CROSS_CHROMA_TX
     }
   }
 
@@ -6933,6 +6980,9 @@
         memcpy(ctx->blk_skip, txfm_info->blk_skip,
                sizeof(txfm_info->blk_skip[0]) * ctx->num_4x4_blk);
         av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
+#if CONFIG_CROSS_CHROMA_TX
+        av1_copy_array(ctx->cctx_type_map, xd->cctx_type_map, ctx->num_4x4_blk);
+#endif  // CONFIG_CROSS_CHROMA_TX
         ctx->rd_stats.skip_txfm = mbmi->skip_txfm[xd->tree_type == CHROMA_PART];
       }
     }
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index 4db1c8c..4e33bbb 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -520,6 +520,9 @@
   const int n4 = bsize_to_num_blk(bsize);
   const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
   memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
+#if CONFIG_CROSS_CHROMA_TX
+  memset(xd->cctx_type_map, CCTX_NONE, sizeof(xd->cctx_type_map[0]) * n4);
+#endif  // CONFIG_CROSS_CHROMA_TX
   memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
 #if CONFIG_NEW_TX_PARTITION
   memset(mbmi->tx_partition_type, TX_PARTITION_NONE,
@@ -1112,6 +1115,12 @@
   if (!is_inter && best_eob &&
       (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
        blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
+#if CONFIG_CROSS_CHROMA_TX
+    CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+#if !CCTX_INTRA
+    assert(cctx_type == CCTX_NONE);
+#endif  // !CCTX_INTRA
+#endif  // CONFIG_CROSS_CHROMA_TX
     // if the quantized coefficients are stored in the dqcoeff buffer, we don't
     // need to do transform and quantization again.
     if (do_quant) {
@@ -1121,7 +1130,11 @@
 #if CONFIG_IST
                       plane,
 #endif
-                      tx_size, best_tx_type, &txfm_param_intra);
+                      tx_size, best_tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                      cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                      &txfm_param_intra);
       av1_setup_quant(tx_size, !skip_trellis,
                       skip_trellis
                           ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
@@ -1137,15 +1150,38 @@
           x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param_intra,
           &quant_param_intra);
       if (quant_param_intra.use_optimize_b) {
-        av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
-                       rate_cost);
+        av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                       cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                       txb_ctx, rate_cost);
       }
     }
 
-    // TODO(kslu) apply inv cctx for u plane once it is needed for intra
-    inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
-                                   x->plane[plane].eobs[block],
-                                   cm->features.reduced_tx_set_used);
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+    // In CONFIG_CROSS_CHROMA_TX, reconstruction for U plane relies on dqcoeffs
+    // of V plane, so the below operators for U are performed together with V
+    // once dqcoeffs of V are obtained.
+    if (plane == AOM_PLANE_V) {
+      tran_low_t *dqcoeff_u =
+          x->plane[AOM_PLANE_U].dqcoeff + BLOCK_OFFSET(block);
+      tran_low_t *dqcoeff_v =
+          x->plane[AOM_PLANE_V].dqcoeff + BLOCK_OFFSET(block);
+      const int max_uv_eob = AOMMAX(x->plane[AOM_PLANE_U].eobs[block],
+                                    x->plane[AOM_PLANE_V].eobs[block]);
+      av1_inv_cross_chroma_tx_block(dqcoeff_u, dqcoeff_v, tx_size, cctx_type);
+      inverse_transform_block_facade(x, AOM_PLANE_U, block, blk_row, blk_col,
+                                     max_uv_eob,
+                                     cm->features.reduced_tx_set_used);
+    }
+    if (plane != AOM_PLANE_U) {
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+      inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
+                                     x->plane[plane].eobs[block],
+                                     cm->features.reduced_tx_set_used);
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
+    }
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
 
     // This may happen because of hash collision. The eob stored in the hash
     // table is non-zero, but the real eob is zero. We need to make sure tx_type
@@ -1218,7 +1254,7 @@
   const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
   const uint8_t *src = &x->plane[plane].src.buf[src_idx];
   const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
-#if CONFIG_IST || CONFIG_CROSS_CHROMA_TX
+#if CONFIG_IST
   tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
 #else
   const tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
@@ -1237,12 +1273,6 @@
   const PLANE_TYPE plane_type = get_plane_type(plane);
   TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
                                     cpi->common.features.reduced_tx_set_used);
-#if CONFIG_CROSS_CHROMA_TX
-  if (is_inter_block(xd->mi[0], xd->tree_type) && plane == AOM_PLANE_U) {
-    tran_low_t *dqcoeff_v = x->plane[AOM_PLANE_V].dqcoeff + BLOCK_OFFSET(block);
-    av1_inv_cross_chroma_tx_block(dqcoeff, dqcoeff_v, tx_size);
-  }
-#endif  // CONFIG_CROSS_CHROMA_TX
   av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
                               MAX_TX_SIZE, eob,
                               cpi->common.features.reduced_tx_set_used);
@@ -1251,6 +1281,81 @@
                          blk_row, blk_col, plane_bsize, tx_bsize);
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+// Evaluate U and V distortion jointly for cross chroma component transform
+// search.
+static INLINE int64_t joint_uv_dist_block_px_domain(
+    const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE plane_bsize, int block,
+    int blk_row, int blk_col, TX_SIZE tx_size) {
+  MACROBLOCKD *const xd = &x->e_mbd;
+  const struct macroblock_plane *const p_u = &x->plane[AOM_PLANE_U];
+  const struct macroblock_plane *const p_v = &x->plane[AOM_PLANE_V];
+  const struct macroblockd_plane *const pd_u = &xd->plane[AOM_PLANE_U];
+  const struct macroblockd_plane *const pd_v = &xd->plane[AOM_PLANE_V];
+  const uint16_t max_uv_eob = AOMMAX(p_u->eobs[block], p_v->eobs[block]);
+  const int eob_max = av1_get_max_eob(tx_size);
+  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
+  const int bsw = block_size_wide[tx_bsize];
+  const int bsh = block_size_high[tx_bsize];
+  // Scale the transform block index to pixel unit.
+  const int src_idx_u = (blk_row * p_u->src.stride + blk_col) << MI_SIZE_LOG2;
+  const int src_idx_v = (blk_row * p_v->src.stride + blk_col) << MI_SIZE_LOG2;
+  const int dst_idx_u = (blk_row * pd_u->dst.stride + blk_col) << MI_SIZE_LOG2;
+  const int dst_idx_v = (blk_row * pd_v->dst.stride + blk_col) << MI_SIZE_LOG2;
+  const uint8_t *src_u = &p_u->src.buf[src_idx_u];
+  const uint8_t *src_v = &p_v->src.buf[src_idx_v];
+  const uint8_t *dst_u = &pd_u->dst.buf[dst_idx_u];
+  const uint8_t *dst_v = &pd_v->dst.buf[dst_idx_v];
+  // p_u->dqcoeff and p_v->dqcoeff must remain unchanged here because the best
+  // dqcoeff in the CCTX domain may be used in the search later.
+  DECLARE_ALIGNED(32, tran_low_t, tmp_dqcoeff_u[MAX_TX_SQUARE]);
+  DECLARE_ALIGNED(32, tran_low_t, tmp_dqcoeff_v[MAX_TX_SQUARE]);
+  memcpy(tmp_dqcoeff_u, p_u->dqcoeff + BLOCK_OFFSET(block),
+         sizeof(tran_low_t) * eob_max);
+  memcpy(tmp_dqcoeff_v, p_v->dqcoeff + BLOCK_OFFSET(block),
+         sizeof(tran_low_t) * eob_max);
+
+#if CCTX_C1_NONZERO
+  assert(p_u->eobs[block] > 0);
+#endif
+  assert(cpi != NULL);
+  assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
+
+  uint8_t *recon_u, *recon_v;
+  DECLARE_ALIGNED(16, uint16_t, recon16_u[MAX_TX_SQUARE]);
+  DECLARE_ALIGNED(16, uint16_t, recon16_v[MAX_TX_SQUARE]);
+
+  recon_u = CONVERT_TO_BYTEPTR(recon16_u);
+  recon_v = CONVERT_TO_BYTEPTR(recon16_v);
+  aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst_u), pd_u->dst.stride,
+                           CONVERT_TO_SHORTPTR(recon_u), MAX_TX_SIZE, bsw, bsh);
+  aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst_v), pd_v->dst.stride,
+                           CONVERT_TO_SHORTPTR(recon_v), MAX_TX_SIZE, bsw, bsh);
+
+  CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+  TX_TYPE tx_type =
+      av1_get_tx_type(xd, PLANE_TYPE_UV, blk_row, blk_col, tx_size,
+                      cpi->common.features.reduced_tx_set_used);
+  av1_inv_cross_chroma_tx_block(tmp_dqcoeff_u, tmp_dqcoeff_v, tx_size,
+                                cctx_type);
+  // TODO(kslu): handle transform domain eobs in addition to cctx domain eobs
+  av1_inverse_transform_block(xd, tmp_dqcoeff_u, AOM_PLANE_U, tx_type, tx_size,
+                              recon_u, MAX_TX_SIZE, max_uv_eob,
+                              cpi->common.features.reduced_tx_set_used);
+  av1_inverse_transform_block(xd, tmp_dqcoeff_v, AOM_PLANE_V, tx_type, tx_size,
+                              recon_v, MAX_TX_SIZE, max_uv_eob,
+                              cpi->common.features.reduced_tx_set_used);
+
+  int64_t dist_u =
+      pixel_dist(cpi, x, AOM_PLANE_U, src_u, p_u->src.stride, recon_u,
+                 MAX_TX_SIZE, blk_row, blk_col, plane_bsize, tx_bsize);
+  int64_t dist_v =
+      pixel_dist(cpi, x, AOM_PLANE_V, src_v, p_v->src.stride, recon_v,
+                 MAX_TX_SIZE, blk_row, blk_col, plane_bsize, tx_bsize);
+  return 16 * (dist_u + dist_v);
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 static uint32_t get_intra_txb_hash(MACROBLOCK *x, int plane, int blk_row,
                                    int blk_col, BLOCK_SIZE plane_bsize,
                                    TX_SIZE tx_size) {
@@ -1408,7 +1513,11 @@
 #if CONFIG_IST
                   plane,
 #endif
-                  tx_size, DCT_DCT, &txfm_param);
+                  tx_size, DCT_DCT,
+#if CONFIG_CROSS_CHROMA_TX
+                  CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                  &txfm_param);
   av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                   &quant_param);
   int tx_type;
@@ -1440,7 +1549,11 @@
 #if CONFIG_FORWARDSKIP
         cm,
 #endif  // CONFIG_FORWARDSKIP
-        x, plane, block, tx_size, tx_type, txb_ctx, reduced_tx_set_used, 0);
+        x, plane, block, tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+        CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+        txb_ctx, reduced_tx_set_used, 0);
 
     rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);
 
@@ -1477,7 +1590,11 @@
 #if CONFIG_FORWARDSKIP
         cm,
 #endif  // CONFIG_FORWARDSKIP
-        x, plane, block, tx_size, tx_type, txb_ctx, reduced_tx_set_used, 0);
+        x, plane, block, tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+        CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+        txb_ctx, reduced_tx_set_used, 0);
 
     rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);
 
@@ -1558,7 +1675,11 @@
 #if CONFIG_IST
                   plane,
 #endif
-                  tx_size, DCT_DCT, &txfm_param);
+                  tx_size, DCT_DCT,
+#if CONFIG_CROSS_CHROMA_TX
+                  CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                  &txfm_param);
   av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                   &quant_param);
 
@@ -1589,7 +1710,11 @@
 #if CONFIG_FORWARDSKIP
         cm,
 #endif  // CONFIG_FORWARDSKIP
-        x, plane, block, tx_size, tx_type, txb_ctx, reduced_tx_set_used, 0);
+        x, plane, block, tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+        CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+        txb_ctx, reduced_tx_set_used, 0);
     // tx domain dist
     dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
 
@@ -2121,6 +2246,13 @@
     }
     assert(num_allowed > 0);
 
+#if CONFIG_DEBUG && CONFIG_CROSS_CHROMA_TX
+    if (plane) {
+      const CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+      assert(cctx_type == CCTX_NONE);
+    }
+#endif  // CONFIG_DEBUG && CONFIG_CROSS_CHROMA_TX
+
     if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
       int pf = prune_factors[txfm_params->prune_2d_txfm_mode];
       int mf = mul_factors[txfm_params->prune_2d_txfm_mode];
@@ -2206,6 +2338,9 @@
     const AV1_COMMON *cm,
 #endif  // CONFIG_FORWARDSKIP
     MACROBLOCK *x, int plane, int block, TX_SIZE tx_size, const TX_TYPE tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+    const CctxType cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
     const TXB_CTX *const txb_ctx, int reduced_tx_set_used) {
 #if TXCOEFF_COST_TIMER
   struct aom_usec_timer timer;
@@ -2215,7 +2350,11 @@
 #if CONFIG_FORWARDSKIP
       cm,
 #endif  // CONFIG_FORWARDSKIP
-      x, plane, block, tx_size, tx_type, txb_ctx, reduced_tx_set_used);
+      x, plane, block, tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+      cctx_type,
+#endif  // CONFIG_CROSS_CHROMA_TX
+      txb_ctx, reduced_tx_set_used);
 #if TXCOEFF_COST_TIMER
   AV1_COMMON *tmp_cm = (AV1_COMMON *)&cpi->common;
   aom_usec_timer_mark(&timer);
@@ -2514,7 +2653,11 @@
 #if CONFIG_IST
                   plane,
 #endif
-                  tx_size, DCT_DCT, &txfm_param);
+                  tx_size, DCT_DCT,
+#if CONFIG_CROSS_CHROMA_TX
+                  CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                  &txfm_param);
 
 #if CONFIG_FORWARDSKIP
   const int xform_quant_b =
@@ -2608,60 +2751,15 @@
       RD_STATS this_rd_stats;
       av1_invalid_rd_stats(&this_rd_stats);
 
-#if CONFIG_CROSS_CHROMA_TX
-      if (is_inter_block(mbmi, xd->tree_type)) {
-        switch (plane) {
-          case AOM_PLANE_Y:
-            if (!dc_only_blk) {
+      if (!dc_only_blk)
 #if CONFIG_IST
-              av1_xform(x, AOM_PLANE_Y, block, blk_row, blk_col, plane_bsize,
-                        &txfm_param, 1);
-#else
-              av1_xform(x, AOM_PLANE_Y, block, blk_row, blk_col, plane_bsize,
-                        &txfm_param);
-#endif
-            } else {
-              av1_xform_dc_only(x, AOM_PLANE_Y, block, &txfm_param,
-                                per_px_mean);
-            }
-            break;
-          case AOM_PLANE_U:
-            if (!dc_only_blk) {
-#if CONFIG_IST
-              av1_xform(x, AOM_PLANE_U, block, blk_row, blk_col, plane_bsize,
-                        &txfm_param, 1);
-              av1_xform(x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
-                        &txfm_param, 1);
-#else
-              av1_xform(x, AOM_PLANE_U, block, blk_row, blk_col, plane_bsize,
-                        &txfm_param);
-              av1_xform(x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
-                        &txfm_param);
-#endif
-            } else {
-              av1_xform_dc_only(x, AOM_PLANE_U, block, &txfm_param,
-                                per_px_mean);
-              av1_xform_dc_only(x, AOM_PLANE_V, block, &txfm_param,
-                                per_px_mean);
-            }
-            forward_cross_chroma_transform(x, block, txfm_param.tx_size);
-            break;
-          case AOM_PLANE_V: break;
-        }
-      } else {
-#endif
-        if (!dc_only_blk)
-#if CONFIG_IST
-          av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
-                    1);
+        av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+                  1);
 #else
       av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param);
 #endif
-        else
-          av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
-#if CONFIG_CROSS_CHROMA_TX
-      }
-#endif  // CONFIG_CROSS_CHROMA_TX
+      else
+        av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
 
 #if CONFIG_IST
       skip_trellis_based_on_satd[txfm_param.tx_type] =
@@ -2700,15 +2798,21 @@
 #endif  // CONFIG_FORWARDSKIP
       // Calculate rate cost of quantized coefficients.
       if (quant_param.use_optimize_b) {
-        av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
-                       &rate_cost);
+        av1_optimize_b(cpi, x, plane, block, tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+                       CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+                       txb_ctx, &rate_cost);
       } else {
         rate_cost = cost_coeffs(
 #if CONFIG_FORWARDSKIP
             cm,
 #endif  // CONFIG_FORWARDSKIP
-            x, plane, block, tx_size, tx_type, txb_ctx,
-            cm->features.reduced_tx_set_used);
+            x, plane, block, tx_size, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+            CCTX_NONE,
+#endif  // CONFIG_CROSS_CHROMA_TX
+            txb_ctx, cm->features.reduced_tx_set_used);
       }
 
       // If rd cost based on coeff rate alone is already more than best_rd,
@@ -2898,6 +3002,211 @@
   p->dqcoeff = orig_dqcoeff;
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+// Search for the best CCTX type for a given transform block.
+static void search_cctx_type(const AV1_COMP *cpi, MACROBLOCK *x, int block,
+                             int blk_row, int blk_col, BLOCK_SIZE plane_bsize,
+                             TX_SIZE tx_size, const TXB_CTX *const txb_ctx_uv,
+                             const int skip_trellis, RD_STATS *best_rd_stats) {
+  const AV1_COMMON *cm = &cpi->common;
+  MACROBLOCKD *xd = &x->e_mbd;
+  MB_MODE_INFO *mbmi = xd->mi[0];
+  struct macroblock_plane *const p_u = &x->plane[AOM_PLANE_U];
+  struct macroblock_plane *const p_v = &x->plane[AOM_PLANE_V];
+
+  int64_t best_rd = RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->dist);
+  uint16_t best_eob_u = p_u->eobs[block];
+  uint16_t best_eob_v = p_v->eobs[block];
+  CctxType best_cctx_type = CCTX_NONE;
+  TX_TYPE tx_type =
+      av1_get_tx_type(xd, PLANE_TYPE_UV, blk_row, blk_col, tx_size,
+                      cpi->common.features.reduced_tx_set_used);
+
+  int rate_cost[2] = { 0, 0 };
+
+  // The buffer used to swap dqcoeff in macroblockd_plane so we can keep dqcoeff
+  // of the best tx_type.
+  DECLARE_ALIGNED(32, tran_low_t, this_dqcoeff_u[MAX_TX_SQUARE]);
+  DECLARE_ALIGNED(32, tran_low_t, this_dqcoeff_v[MAX_TX_SQUARE]);
+  tran_low_t *orig_dqcoeff_u = p_u->dqcoeff;
+  tran_low_t *orig_dqcoeff_v = p_v->dqcoeff;
+  tran_low_t *best_dqcoeff_u = this_dqcoeff_u;
+  tran_low_t *best_dqcoeff_v = this_dqcoeff_v;
+
+  uint16_t *eobs_ptr_u = x->plane[AOM_PLANE_U].eobs;
+  uint16_t *eobs_ptr_v = x->plane[AOM_PLANE_V].eobs;
+  uint8_t best_txb_ctx_u = 0;
+  uint8_t best_txb_ctx_v = 0;
+
+  TxfmParam txfm_param;
+  av1_setup_xform(cm, x,
+#if CONFIG_IST
+                  AOM_PLANE_U,
+#endif
+                  tx_size, tx_type, CCTX_NONE, &txfm_param);
+
+  // CCTX is performed in-place, so these buffers are needed to store original
+  // transform coefficients.
+  const int max_eob = av1_get_max_eob(tx_size);
+  DECLARE_ALIGNED(32, tran_low_t, orig_coeff_u[MAX_TX_SQUARE]);
+  DECLARE_ALIGNED(32, tran_low_t, orig_coeff_v[MAX_TX_SQUARE]);
+  memcpy(orig_coeff_u, p_u->coeff + BLOCK_OFFSET(block),
+         sizeof(tran_low_t) * max_eob);
+  memcpy(orig_coeff_v, p_v->coeff + BLOCK_OFFSET(block),
+         sizeof(tran_low_t) * max_eob);
+
+  // Iterate through all transform type candidates.
+  for (CctxType cctx_type = CCTX_START; cctx_type < CCTX_TYPES; ++cctx_type) {
+    RD_STATS this_rd_stats;
+    av1_invalid_rd_stats(&this_rd_stats);
+
+    update_cctx_array(xd, blk_row, blk_col, tx_size, cctx_type);
+    forward_cross_chroma_transform(x, block, tx_size, cctx_type);
+
+    for (int plane = AOM_PLANE_U; plane <= AOM_PLANE_V; plane++) {
+#if CCTX_C2_DROPPED
+      if (plane == AOM_PLANE_V && !keep_chroma_c2(cctx_type)) {
+        memset(p_v->dqcoeff + BLOCK_OFFSET(block), 0,
+               max_eob * sizeof(p_v->dqcoeff));
+        eobs_ptr_v[block] = 0;
+        rate_cost[1] = 0;
+        break;
+      }
+#endif
+
+      QUANT_PARAM quant_param;
+      // TODO(kslu): need to search skip trellis?
+      int xform_quant_idx = skip_trellis
+                                ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
+                                                          : AV1_XFORM_QUANT_FP)
+                                : AV1_XFORM_QUANT_FP;
+      av1_setup_quant(tx_size, !skip_trellis, xform_quant_idx,
+                      cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
+
+      if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id))
+        av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
+                          &quant_param);
+
+      av1_quant(x, plane, block, &txfm_param, &quant_param);
+
+      // Calculate rate cost of quantized coefficients.
+      if (quant_param.use_optimize_b) {
+        av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, cctx_type,
+                       &txb_ctx_uv[plane - AOM_PLANE_U],
+                       &rate_cost[plane - AOM_PLANE_U]);
+      } else {
+        rate_cost[plane - AOM_PLANE_U] = cost_coeffs(
+#if CONFIG_FORWARDSKIP
+            cm,
+#endif  // CONFIG_FORWARDSKIP
+            x, plane, block, tx_size, tx_type, cctx_type,
+            &txb_ctx_uv[plane - AOM_PLANE_U], cm->features.reduced_tx_set_used);
+      }
+    }
+#if CCTX_C1_NONZERO
+    // TODO(kslu) for negative angles, skip av1_xform_quant and reuse previous
+    // dqcoeffs
+    uint64_t sse_dqcoeff_u =
+        aom_sum_squares_i16((int16_t *)p_u->dqcoeff, (uint32_t)max_eob);
+    uint64_t sse_dqcoeff_v =
+        aom_sum_squares_i16((int16_t *)p_v->dqcoeff, (uint32_t)max_eob);
+    // Disallow the case where C1 eob is zero in cctx
+    if (eobs_ptr_u[block] == 0 || sse_dqcoeff_v > sse_dqcoeff_u) {
+      // Recover the original transform coefficients
+      if (cctx_type < CCTX_TYPES - 1) {
+        memcpy(p_u->coeff + BLOCK_OFFSET(block), orig_coeff_u,
+               sizeof(tran_low_t) * max_eob);
+        memcpy(p_v->coeff + BLOCK_OFFSET(block), orig_coeff_v,
+               sizeof(tran_low_t) * max_eob);
+      }
+      continue;
+    }
+#endif
+
+    // If rd cost based on coeff rate alone is already more than best_rd,
+    // terminate early.
+    if (RDCOST(x->rdmult, rate_cost[0] + rate_cost[1], 0) > best_rd) continue;
+
+    // Calculate distortion.
+    if (eobs_ptr_u[block] == 0 && eobs_ptr_v[block] == 0) {
+      // When eob is 0, pixel domain distortion is more efficient and accurate.
+      this_rd_stats.dist = this_rd_stats.sse = best_rd_stats->sse;
+    } else {
+      this_rd_stats.dist = joint_uv_dist_block_px_domain(
+          cpi, x, plane_bsize, block, blk_row, blk_col, tx_size);
+      this_rd_stats.sse = best_rd_stats->sse;
+    }
+
+    this_rd_stats.rate = rate_cost[0] + rate_cost[1];
+
+    const int64_t rd =
+        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
+
+    if (rd < best_rd) {
+      best_rd = rd;
+      *best_rd_stats = this_rd_stats;
+      best_cctx_type = cctx_type;
+      best_txb_ctx_u = p_u->txb_entropy_ctx[block];
+      best_txb_ctx_v = p_v->txb_entropy_ctx[block];
+      best_eob_u = p_u->eobs[block];
+      best_eob_v = p_v->eobs[block];
+      // Swap dqcoeff buffers
+      tran_low_t *const tmp_dqcoeff_u = best_dqcoeff_u;
+      tran_low_t *const tmp_dqcoeff_v = best_dqcoeff_v;
+      best_dqcoeff_u = p_u->dqcoeff;
+      best_dqcoeff_v = p_v->dqcoeff;
+      p_u->dqcoeff = tmp_dqcoeff_u;
+      p_v->dqcoeff = tmp_dqcoeff_v;
+    }
+
+    // Recover the original transform coefficients
+    if (cctx_type < CCTX_TYPES - 1) {
+      memcpy(p_u->coeff + BLOCK_OFFSET(block), orig_coeff_u,
+             sizeof(tran_low_t) * max_eob);
+      memcpy(p_v->coeff + BLOCK_OFFSET(block), orig_coeff_v,
+             sizeof(tran_low_t) * max_eob);
+    }
+  }
+
+  assert(best_rd != INT64_MAX);
+
+  best_rd_stats->skip_txfm = (best_eob_u == 0 && best_eob_v == 0);
+  update_cctx_array(xd, blk_row, blk_col, tx_size, best_cctx_type);
+  p_u->txb_entropy_ctx[block] = best_txb_ctx_u;
+  p_v->txb_entropy_ctx[block] = best_txb_ctx_v;
+  p_u->eobs[block] = best_eob_u;
+  p_v->eobs[block] = best_eob_v;
+
+#if CCTX_C1_NONZERO
+  assert(IMPLIES(best_cctx_type > CCTX_NONE, best_eob_u > 0));
+#endif
+#if CCTX_C2_DROPPED
+  assert(IMPLIES(!keep_chroma_c2(best_cctx_type), best_eob_v == 0));
+#endif
+
+#if CCTX_INTRA
+  // Point dqcoeff to the quantized coefficients corresponding to the best
+  // transform type, then we can skip transform and quantization, e.g. in the
+  // final pixel domain distortion calculation and recon_intra().
+  p_u->dqcoeff = best_dqcoeff_u;
+  p_v->dqcoeff = best_dqcoeff_v;
+
+  // TODO(kslu) double check the removal of calc_pixel_domain_distortion_final
+
+  // Intra mode needs decoded pixels such that the next transform block can use
+  // them for prediction.
+  recon_intra(cpi, x, AOM_PLANE_U, block, blk_row, blk_col, plane_bsize,
+              tx_size, &txb_ctx_uv[0], skip_trellis, tx_type, 0, &rate_cost[0],
+              AOMMAX(best_eob_u, best_eob_v));
+  recon_intra(cpi, x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
+              tx_size, &txb_ctx_uv[1], skip_trellis, tx_type, 0, &rate_cost[1],
+              AOMMAX(best_eob_u, best_eob_v));
+#endif  // CCTX_INTRA
+  p_u->dqcoeff = orig_dqcoeff_u;
+  p_v->dqcoeff = orig_dqcoeff_v;
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 // Pick transform type for a luma transform block of tx_size. Note this function
 // is used only for inter-predicted blocks.
 static AOM_INLINE void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
@@ -3960,6 +4269,136 @@
 }
 #endif  // !CONFIG_NEW_TX_PARTITION
 
+#if CONFIG_CROSS_CHROMA_TX
+static AOM_INLINE void block_rd_txfm_joint_uv(int dummy_plane, int block,
+                                              int blk_row, int blk_col,
+                                              BLOCK_SIZE plane_bsize,
+                                              TX_SIZE tx_size, void *arg) {
+  (void)dummy_plane;
+  struct rdcost_block_args *args = arg;
+  MACROBLOCK *const x = args->x;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  const int is_inter = is_inter_block(xd->mi[0], xd->tree_type);
+
+  const AV1_COMP *cpi = args->cpi;
+  const AV1_COMMON *cm = &cpi->common;
+  RD_STATS rd_stats_joint_uv;
+  av1_init_rd_stats(&rd_stats_joint_uv);
+  update_cctx_array(xd, blk_row, blk_col, tx_size, CCTX_NONE);
+
+  // Obtain RD cost for CCTX_NONE
+  RD_STATS rd_stats_uv[2];
+  av1_init_rd_stats(&rd_stats_uv[0]);
+  av1_init_rd_stats(&rd_stats_uv[1]);
+  TXB_CTX txb_ctx_uv[2];
+  for (int plane = AOM_PLANE_U; plane <= AOM_PLANE_V; ++plane) {
+    RD_STATS *this_rd_stats = &rd_stats_uv[plane - AOM_PLANE_U];
+    TXB_CTX *txb_ctx = &txb_ctx_uv[plane - AOM_PLANE_U];
+
+    // TODO(kslu): maybe remove this feature
+    if (args->exit_early) args->incomplete_exit = 1;
+
+    if (!is_inter) {
+      av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
+      av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
+    }
+
+    const struct macroblockd_plane *const pd = &xd->plane[plane];
+    av1_get_entropy_contexts(plane_bsize, pd, args->t_above, args->t_left);
+
+    ENTROPY_CONTEXT *a = args->t_above + blk_col;
+    ENTROPY_CONTEXT *l = args->t_left + blk_row;
+
+    get_txb_ctx(plane_bsize, tx_size, plane, a, l, txb_ctx
+#if CONFIG_FORWARDSKIP
+                ,
+                xd->mi[0]->fsc_mode[xd->tree_type == CHROMA_PART]
+#endif  // CONFIG_FORWARDSKIP
+    );
+    search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
+                   txb_ctx, args->ftxs_mode, args->skip_trellis,
+                   args->best_rd - args->current_rd, this_rd_stats);
+#if CONFIG_FORWARDSKIP
+    if (this_rd_stats->dist == INT64_MAX) {
+      args->exit_early = 1;
+      args->incomplete_exit = 1;
+    }
+#endif  // CONFIG_FORWARDSKIP
+
+#if CONFIG_RD_DEBUG
+    update_txb_coeff_cost(this_rd_stats, plane, tx_size, blk_row, blk_col,
+                          this_rd_stats->rate);
+#endif  // CONFIG_RD_DEBUG
+    av1_set_txb_context(x, plane, block, tx_size, a, l);
+
+    const int blk_idx =
+        blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;
+    TxfmSearchInfo *txfm_info = &x->txfm_search_info;
+    set_blk_skip(txfm_info->blk_skip, plane, blk_idx, 0);
+
+    int64_t rd;
+    if (is_inter) {
+      const int64_t no_skip_txfm_rd =
+          RDCOST(x->rdmult, this_rd_stats->rate, this_rd_stats->dist);
+      const int64_t skip_txfm_rd = RDCOST(x->rdmult, 0, this_rd_stats->sse);
+      rd = AOMMIN(no_skip_txfm_rd, skip_txfm_rd);
+      this_rd_stats->skip_txfm &= !x->plane[plane].eobs[block];
+    } else {
+      // Signal non-skip_txfm for Intra blocks
+      rd = RDCOST(x->rdmult, this_rd_stats->rate, this_rd_stats->dist);
+      this_rd_stats->skip_txfm = 0;
+    }
+
+    args->current_rd += rd;
+    av1_merge_rd_stats(&rd_stats_joint_uv, this_rd_stats);
+  }
+
+  if (!rd_stats_uv[0].skip_txfm || !rd_stats_uv[1].skip_txfm) {
+    search_cctx_type(cpi, x, block, blk_row, blk_col, plane_bsize, tx_size,
+                     txb_ctx_uv, args->skip_trellis, &rd_stats_joint_uv);
+  }
+  av1_merge_rd_stats(&args->rd_stats, &rd_stats_joint_uv);
+}
+
+void av1_txfm_rd_joint_uv(MACROBLOCK *x, const AV1_COMP *cpi,
+                          RD_STATS *rd_stats, int64_t ref_best_rd,
+                          int64_t current_rd, BLOCK_SIZE plane_bsize,
+                          TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
+                          int skip_trellis) {
+  if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
+      txsize_sqr_up_map[tx_size] == TX_64X64) {
+    av1_invalid_rd_stats(rd_stats);
+    return;
+  }
+
+  MACROBLOCKD *const xd = &x->e_mbd;
+  struct rdcost_block_args args;
+  av1_zero(args);
+  args.x = x;
+  args.cpi = cpi;
+  args.best_rd = ref_best_rd;
+  args.current_rd = current_rd;
+  args.ftxs_mode = ftxs_mode;
+  args.skip_trellis = skip_trellis;
+  av1_init_rd_stats(&args.rd_stats);
+
+  // Note: this only works when subsampling_x and subsampling_y are the same
+  // for U and V
+  av1_foreach_transformed_block_in_plane(xd, plane_bsize, AOM_PLANE_U,
+                                         block_rd_txfm_joint_uv, &args);
+
+  MB_MODE_INFO *const mbmi = xd->mi[0];
+  const int is_inter = is_inter_block(mbmi, xd->tree_type);
+  const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;
+
+  if (invalid_rd) {
+    av1_invalid_rd_stats(rd_stats);
+  } else {
+    *rd_stats = args.rd_stats;
+  }
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 // Search for the best transform size and type for current inter-predicted
 // luma block with recursive transform block partitioning. The obtained
 // transform selection will be saved in xd->mi[0], the corresponding RD stats
@@ -4232,6 +4671,11 @@
     return;
   }
 
+#if CONFIG_CROSS_CHROMA_TX
+  const int n4 = bsize_to_num_blk(bs);
+  memset(xd->cctx_type_map, CCTX_NONE, sizeof(xd->cctx_type_map[0]) * n4);
+#endif  // CONFIG_CROSS_CHROMA_TX
+
   if (xd->lossless[mbmi->segment_id]) {
     // Lossless mode can only pick the smallest (4x4) transform size.
     choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
@@ -4269,31 +4713,54 @@
   const int skip_trellis = 0;
   const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
   int is_cost_valid = 1;
-  for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
+#if CONFIG_CROSS_CHROMA_TX
+  if ((is_inter && CCTX_INTER) || (!is_inter && CCTX_INTRA)) {
+    // TODO(kslu): apply the early exit mechanism?
     RD_STATS this_rd_stats;
     int64_t chroma_ref_best_rd = ref_best_rd;
-    // For inter blocks, refined ref_best_rd is used for early exit
-    // For intra blocks, even though current rd crosses ref_best_rd, early
-    // exit is not recommended as current rd is used for gating subsequent
-    // modes as well (say, for angular modes)
-    // TODO(any): Extend the early exit mechanism for intra modes as well
     if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter &&
         chroma_ref_best_rd != INT64_MAX)
       chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_txfm_rd);
-    av1_txfm_rd_in_plane(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0, plane,
+    av1_txfm_rd_joint_uv(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0,
                          plane_bsize, uv_tx_size, FTXS_NONE, skip_trellis);
     if (this_rd_stats.rate == INT_MAX) {
       is_cost_valid = 0;
-      break;
+    } else {
+      av1_merge_rd_stats(rd_stats, &this_rd_stats);
+      this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
+      skip_txfm_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
+      if (AOMMIN(this_rd, skip_txfm_rd) > ref_best_rd) is_cost_valid = 0;
     }
-    av1_merge_rd_stats(rd_stats, &this_rd_stats);
-    this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
-    skip_txfm_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
-    if (AOMMIN(this_rd, skip_txfm_rd) > ref_best_rd) {
-      is_cost_valid = 0;
-      break;
+  } else {
+#endif  // CONFIG_CROSS_CHROMA_TX
+    for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
+      RD_STATS this_rd_stats;
+      int64_t chroma_ref_best_rd = ref_best_rd;
+      // For inter blocks, refined ref_best_rd is used for early exit
+      // For intra blocks, even though current rd crosses ref_best_rd, early
+      // exit is not recommended as current rd is used for gating subsequent
+      // modes as well (say, for angular modes)
+      // TODO(any): Extend the early exit mechanism for intra modes as well
+      if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
+          is_inter && chroma_ref_best_rd != INT64_MAX)
+        chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_txfm_rd);
+      av1_txfm_rd_in_plane(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0, plane,
+                           plane_bsize, uv_tx_size, FTXS_NONE, skip_trellis);
+      if (this_rd_stats.rate == INT_MAX) {
+        is_cost_valid = 0;
+        break;
+      }
+      av1_merge_rd_stats(rd_stats, &this_rd_stats);
+      this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
+      skip_txfm_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
+      if (AOMMIN(this_rd, skip_txfm_rd) > ref_best_rd) {
+        is_cost_valid = 0;
+        break;
+      }
     }
+#if CONFIG_CROSS_CHROMA_TX
   }
+#endif  // CONFIG_CROSS_CHROMA_TX
 
   if (!is_cost_valid) {
     // reset cost value
diff --git a/tools/aom_entropy_optimizer.c b/tools/aom_entropy_optimizer.c
index ce10657..a72a843 100644
--- a/tools/aom_entropy_optimizer.c
+++ b/tools/aom_entropy_optimizer.c
@@ -372,6 +372,15 @@
                      "[CDF_SIZE(DDT_TYPES)]");
 #endif  // CONFIG_DDT_INTER
 
+#if CONFIG_CROSS_CHROMA_TX
+  /* cctx type */
+  cts_each_dim[0] = EXT_TX_SIZES;
+  cts_each_dim[1] = CCTX_TYPES;
+  optimize_cdf_table(&fc.cctx_type[0][0], probsfile, 2, cts_each_dim,
+                     "static const aom_cdf_prob default_cctx_type[EXT_TX_SIZES]"
+                     "[CDF_SIZE(CCTX_TYPES)]");
+#endif  // CONFIG_CROSS_CHROMA_TX
+
   /* tx type */
   cts_each_dim[0] = EXT_TX_SETS_INTRA;
   cts_each_dim[1] = EXT_TX_SIZES;