Cross chroma transform: add contexts and optimize CDFs

Use the above and left cctx_type as contexts for coding cctx_type, optimize the default CDFs, and apply miscellaneous cleanups.
diff --git a/av1/common/av1_common_int.h b/av1/common/av1_common_int.h
index 06123ba..cfded4d 100644
--- a/av1/common/av1_common_int.h
+++ b/av1/common/av1_common_int.h
@@ -2230,10 +2230,12 @@
   // 'xd->tx_type_map' should point to an offset in 'mi_params->tx_type_map'.
   if (xd->tree_type != CHROMA_PART) {
     xd->tx_type_map = mi_params->tx_type_map + mi_grid_idx;
-#if CONFIG_CROSS_CHROMA_TX
-    xd->cctx_type_map = mi_params->cctx_type_map + mi_grid_idx;
-#endif  // CONFIG_CROSS_CHROMA_TX
   }
+#if CONFIG_CROSS_CHROMA_TX
+  if (xd->tree_type != LUMA_PART) {  // cctx types are chroma-only
+    xd->cctx_type_map = mi_params->cctx_type_map + mi_grid_idx;
+  }
+#endif  // CONFIG_CROSS_CHROMA_TX
   xd->tx_type_map_stride = mi_params->mi_stride;
 }
 
diff --git a/av1/common/av1_txfm.c b/av1/common/av1_txfm.c
index 267c59e..6c2e2c9 100644
--- a/av1/common/av1_txfm.c
+++ b/av1/common/av1_txfm.c
@@ -202,54 +202,34 @@
 // stores two values: cos(t) and sin(t) for each rotation angle.
 const int32_t cctx_mtx[CCTX_TYPES - 1][2] = {
 #if CCTX_ANGLE_CONFIG == 0
-#if CCTX_POS_ANGLES
-  { 181, 181 },  // t = 45 degrees
-  { 222, 128 },  // t = 30 degrees
-  { 128, 222 },  // t = 60 degrees
-#endif           // CCTX_POS_ANGLES
-#if CCTX_NEG_ANGLES
+  { 181, 181 },   // t = 45 degrees (Haar transform)
+  { 222, 128 },   // t = 30 degrees
+  { 128, 222 },   // t = 60 degrees
   { 181, -181 },  // t = -45 degrees
   { 222, -128 },  // t = -30 degrees
   { 128, -222 },  // t = -60 degrees
-#endif            // CCTX_NEG_ANGLES
 #elif CCTX_ANGLE_CONFIG == 1
-#if CCTX_POS_ANGLES
   { 181, 181 },   // t = 45 degrees
   { 236, 98 },    // t = 22.5 degrees
   { 98, 236 },    // t = 67.5 degrees
-#endif  // CCTX_POS_ANGLES
-#if CCTX_NEG_ANGLES
   { 181, -181 },  // t = -45 degrees
   { 236, -98 },   // t = -22.5 degrees
   { 98, -236 },   // t = -67.5 degrees
-#endif  // CCTX_NEG_ANGLES
 #elif CCTX_ANGLE_CONFIG == 2
-#if CCTX_POS_ANGLES
   { 181, 181 },   // t = 45 degrees
   { 232, 108 },   // t = 25 degrees
   { 108, 232 },   // t = 65 degrees
-#endif  // CCTX_POS_ANGLES
-#if CCTX_NEG_ANGLES
   { 181, -181 },  // t = -45 degrees
   { 232, -108 },  // t = -25 degrees
   { 108, -232 },  // t = -65 degrees
-#endif  // CCTX_NEG_ANGLES
 #elif CCTX_ANGLE_CONFIG == 3
-#if CCTX_POS_ANGLES
   { 222, 128 },   // t = 30 degrees
   { 128, 222 },   // t = 60 degrees
-#endif  // CCTX_POS_ANGLES
-#if CCTX_NEG_ANGLES
   { 222, -128 },  // t = -30 degrees
   { 128, -222 },  // t = -60 degrees
-#endif  // CCTX_NEG_ANGLES
 #elif CCTX_ANGLE_CONFIG == 4
-#if CCTX_POS_ANGLES
   { 181, 181 },   // t = 45 degrees
-#endif  // CCTX_POS_ANGLES
-#if CCTX_NEG_ANGLES
   { 181, -181 },  // t = -45 degrees
-#endif  // CCTX_NEG_ANGLES
 #endif
 };
 #endif  // CONFIG_CROSS_CHROMA_TX
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index aee81be..9f22162 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -1564,38 +1564,56 @@
 #if CONFIG_CROSS_CHROMA_TX
 #if CCTX_C2_DROPPED
 static INLINE int keep_chroma_c2(CctxType cctx_type) {
-  return
-#if CCTX_NEG_ANGLES
-      cctx_type == CCTX_M30 || cctx_type == CCTX_M60 ||
-#endif
-#if CCTX_POS_ANGLES
-      cctx_type == CCTX_30 || cctx_type == CCTX_60 ||
-#endif
-      cctx_type == CCTX_NONE;
+  return cctx_type == CCTX_M30 || cctx_type == CCTX_M60 ||
+         cctx_type == CCTX_30 || cctx_type == CCTX_60 || cctx_type == CCTX_NONE;
 }
 #endif
 
+// Sets each offset to 1 when the block is at an odd mi row/col, tx_size is one
+// unit tall/wide and that axis is subsampled (sub-8x8 chroma); else to 0.
+static INLINE void get_offsets_to_8x8(const MACROBLOCKD *xd, TX_SIZE tx_size,
+                                      int *row_offset, int *col_offset) {
+  const struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
+  const int ss_x = pd->subsampling_x;
+  const int ss_y = pd->subsampling_y;
+  *row_offset =
+      (xd->mi_row & 0x01) && (tx_size_high_unit[tx_size] & 0x01) && ss_y;
+  *col_offset =
+      (xd->mi_col & 0x01) && (tx_size_wide_unit[tx_size] & 0x01) && ss_x;
+}
+
 static INLINE void update_cctx_array(MACROBLOCKD *const xd, int blk_row,
-                                     int blk_col, TX_SIZE tx_size,
+                                     int blk_col, int blk_row_offset,
+                                     int blk_col_offset, TX_SIZE tx_size,
                                      CctxType cctx_type) {
   const int stride = xd->tx_type_map_stride;
-  xd->cctx_type_map[blk_row * stride + blk_col] = cctx_type;
+  const struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
+  const int ss_x = pd->subsampling_x;
+  const int ss_y = pd->subsampling_y;
+  assert(xd->is_chroma_ref);
 
-  const int txw = tx_size_wide_unit[tx_size];
-  const int txh = tx_size_high_unit[tx_size];
-  // The 16x16 unit is due to the constraint from tx_64x64 which sets the
-  // maximum tx size for chroma as 32x32. Coupled with 4x1 transform block
-  // size, the constraint takes effect in 32x16 / 16x32 size too. To solve
-  // the intricacy, cover all the 16x16 units inside a 64 level transform.
-  if (txw == tx_size_wide_unit[TX_64X64] ||
-      txh == tx_size_high_unit[TX_64X64]) {
-    const int tx_unit = tx_size_wide_unit[TX_16X16];
-    for (int idy = 0; idy < txh; idy += tx_unit) {
-      for (int idx = 0; idx < txw; idx += tx_unit) {
-        xd->cctx_type_map[(blk_row + idy) * stride + blk_col + idx] = cctx_type;
-      }
-    }
-  }
+  // For a sub-8x8 block, the offsets move (br, bc) back to the mi position of
+  // its parent 8x8 area. Positions and sizes are converted from chroma tx
+  // units to mi units by shifting with the chroma subsampling factors.
+  const int br = (blk_row << ss_y) - blk_row_offset;
+  const int bc = (blk_col << ss_x) - blk_col_offset;
+  const int txw = tx_size_wide_unit[tx_size] << ss_x;
+  const int txh = tx_size_high_unit[tx_size] << ss_y;
+
+  // Fill every cctx_type_map entry covered by the transform block so that the
+  // right and bottom neighbors can read it as context. memset is valid since
+  // CctxType is declared as a one-byte enum (UENUM1BYTE).
+  for (int idy = 0; idy < txh; idy++)
+    memset(&xd->cctx_type_map[(br + idy) * stride + bc], cctx_type,
+           txw * sizeof(xd->cctx_type_map[0]));
+}
+
+static INLINE CctxType av1_get_cctx_type(const MACROBLOCKD *xd, int blk_row,
+                                         int blk_col) {
+  const struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
+  const int br = blk_row << pd->subsampling_y;
+  const int bc = blk_col << pd->subsampling_x;
+  return xd->cctx_type_map[br * xd->tx_type_map_stride + bc];
 }
 #endif  // CONFIG_CROSS_CHROMA_TX
 
@@ -1777,13 +1795,6 @@
   return tx_type;
 }
 
-#if CONFIG_CROSS_CHROMA_TX
-static INLINE CctxType av1_get_cctx_type(const MACROBLOCKD *xd, int blk_row,
-                                         int blk_col) {
-  return xd->cctx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
-}
-#endif  // CONFIG_CROSS_CHROMA_TX
-
 void av1_setup_block_planes(MACROBLOCKD *xd, int ss_x, int ss_y,
                              const int num_planes);
 
diff --git a/av1/common/entropy.c b/av1/common/entropy.c
index 86b6465..b6933d5 100644
--- a/av1/common/entropy.c
+++ b/av1/common/entropy.c
@@ -296,6 +296,6 @@
   RESET_CDF_COUNTER_STRIDE(fc->stx_cdf, STX_TYPES, CDF_SIZE(STX_TYPES));
 #endif
 #if CONFIG_CROSS_CHROMA_TX
-  RESET_CDF_COUNTER(fc->cctx_type_cdf, CCTX_TYPES);
+  RESET_CDF_COUNTER(fc->cctx_type_cdf, CCTX_TYPES_ALLOWED);
 #endif  // CONFIG_CROSS_CHROMA_TX
 }
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index e48fea4..f00070d 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -684,62 +684,77 @@
 #endif  // CONFIG_DDT_INTER
 
 #if CONFIG_CROSS_CHROMA_TX
-static const aom_cdf_prob default_cctx_type_cdf[EXT_TX_SIZES]
-                                               [CDF_SIZE(CCTX_TYPES)] = {
+static const aom_cdf_prob
+    default_cctx_type_cdf[EXT_TX_SIZES][CCTX_CONTEXTS]
+                         [CDF_SIZE(CCTX_TYPES_ALLOWED)] = {
 #if CCTX_ANGLE_CONFIG == 4
-#if CCTX_POS_ANGLES && CCTX_NEG_ANGLES
-                                                 { AOM_CDF3(10923, 21845) },
-                                                 { AOM_CDF3(10923, 21845) },
-                                                 { AOM_CDF3(10923, 21845) },
-                                                 { AOM_CDF3(10923, 21845) },
-#else
-                                                 { AOM_CDF2(16384) },
-                                                 { AOM_CDF2(16384) },
-                                                 { AOM_CDF2(16384) },
-                                                 { AOM_CDF2(16384) },
-#endif  // CCTX_POS_ANGLES && CCTX_NEG_ANGLES
+                           { { AOM_CDF3(10923, 21845) },
+                             { AOM_CDF3(10923, 21845) },
+                             { AOM_CDF3(10923, 21845) } },
+                           { { AOM_CDF3(10923, 21845) },
+                             { AOM_CDF3(10923, 21845) },
+                             { AOM_CDF3(10923, 21845) } },
+                           { { AOM_CDF3(10923, 21845) },
+                             { AOM_CDF3(10923, 21845) },
+                             { AOM_CDF3(10923, 21845) } },
+                           { { AOM_CDF3(10923, 21845) },
+                             { AOM_CDF3(10923, 21845) },
+                             { AOM_CDF3(10923, 21845) } },
 #elif CCTX_ANGLE_CONFIG == 3
-#if CCTX_POS_ANGLES && CCTX_NEG_ANGLES
-                                                 { AOM_CDF5(6554, 13107, 19661,
-                                                            26214) },
-                                                 { AOM_CDF5(6554, 13107, 19661,
-                                                            26214) },
-                                                 { AOM_CDF5(6554, 13107, 19661,
-                                                            26214) },
-                                                 { AOM_CDF5(6554, 13107, 19661,
-                                                            26214) },
+                           { { AOM_CDF5(6554, 13107, 19661, 26214) },
+                             { AOM_CDF5(6554, 13107, 19661, 26214) },
+                             { AOM_CDF5(6554, 13107, 19661, 26214) } },
+                           { { AOM_CDF5(6554, 13107, 19661, 26214) },
+                             { AOM_CDF5(6554, 13107, 19661, 26214) },
+                             { AOM_CDF5(6554, 13107, 19661, 26214) } },
+                           { { AOM_CDF5(6554, 13107, 19661, 26214) },
+                             { AOM_CDF5(6554, 13107, 19661, 26214) },
+                             { AOM_CDF5(6554, 13107, 19661, 26214) } },
+                           { { AOM_CDF5(6554, 13107, 19661, 26214) },
+                             { AOM_CDF5(6554, 13107, 19661, 26214) },
+                             { AOM_CDF5(6554, 13107, 19661, 26214) } },
 #else
-                                                 { AOM_CDF3(10923, 21845) },
-                                                 { AOM_CDF3(10923, 21845) },
-                                                 { AOM_CDF3(10923, 21845) },
-                                                 { AOM_CDF3(10923, 21845) },
-#endif  // CCTX_POS_ANGLES && CCTX_NEG_ANGLES
+#if CCTX_ADAPT_REDUCED_SET
+                           { { AOM_CDF3(21092, 22824) },
+                             { AOM_CDF3(23150, 29387) },
+                             { AOM_CDF3(12029, 25649) } },
+                           { { AOM_CDF3(17069, 19010) },
+                             { AOM_CDF3(22591, 27810) },
+                             { AOM_CDF3(10803, 24548) } },
+                           { { AOM_CDF3(16209, 18900) },
+                             { AOM_CDF3(21703, 27555) },
+                             { AOM_CDF3(8605, 23810) } },
+                           { { AOM_CDF3(15354, 17963) },
+                             { AOM_CDF3(22686, 27727) },
+                             { AOM_CDF3(9173, 22932) } }
 #else
-#if CCTX_POS_ANGLES && CCTX_NEG_ANGLES
-                                                 { AOM_CDF7(4681, 9362, 14043,
-                                                            18725, 23406,
-                                                            28087) },
-                                                 { AOM_CDF7(4681, 9362, 14043,
-                                                            18725, 23406,
-                                                            28087) },
-                                                 { AOM_CDF7(4681, 9362, 14043,
-                                                            18725, 23406,
-                                                            28087) },
-                                                 { AOM_CDF7(4681, 9362, 14043,
-                                                            18725, 23406,
-                                                            28087) },
-#else
-                                                 { AOM_CDF4(8192, 16384,
-                                                            24576) },
-                                                 { AOM_CDF4(8192, 16384,
-                                                            24576) },
-                                                 { AOM_CDF4(8192, 16384,
-                                                            24576) },
-                                                 { AOM_CDF4(8192, 16384,
-                                                            24576) },
-#endif  // CCTX_POS_ANGLES && CCTX_NEG_ANGLES
+                           { { AOM_CDF7(19143, 19642, 20876, 21362, 23684,
+                                        30645) },
+                             { AOM_CDF7(15852, 17519, 22430, 24276, 26473,
+                                        30362) },
+                             { AOM_CDF7(9981, 10351, 11021, 11340, 16893,
+                                        28901) } },
+                           { { AOM_CDF7(13312, 14068, 15345, 16249, 20082,
+                                        29648) },
+                             { AOM_CDF7(11802, 14635, 17918, 20493, 23927,
+                                        29206) },
+                             { AOM_CDF7(8348, 8915, 9727, 10347, 16584,
+                                        27923) } },
+                           { { AOM_CDF7(10604, 11887, 13486, 14485, 19798,
+                                        28529) },
+                             { AOM_CDF7(10790, 13346, 16867, 18854, 23398,
+                                        29133) },
+                             { AOM_CDF7(6538, 7104, 7997, 8723, 15658,
+                                        26864) } },
+                           { { AOM_CDF7(13226, 13959, 14918, 15707, 21009,
+                                        29328) },
+                             { AOM_CDF7(10336, 13195, 15614, 17813, 21992,
+                                        29469) },
+                             { AOM_CDF7(7769, 8772, 9617, 10150, 16729,
+                                        28132) } }
 #endif
-                                               };
+#endif  // CCTX_ANGLE_CONFIG
+                         };
 #endif  // CONFIG_CROSS_CHROMA_TX
 
 static const aom_cdf_prob default_cfl_sign_cdf[CDF_SIZE(CFL_JOINT_SIGNS)] = {
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index c2267a0..af9a7f2 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -273,7 +273,8 @@
   aom_cdf_prob stx_cdf[TX_SIZES][CDF_SIZE(STX_TYPES)];
 #endif
 #if CONFIG_CROSS_CHROMA_TX
-  aom_cdf_prob cctx_type_cdf[EXT_TX_SIZES][CDF_SIZE(CCTX_TYPES)];
+  aom_cdf_prob cctx_type_cdf[EXT_TX_SIZES][CCTX_CONTEXTS]
+                            [CDF_SIZE(CCTX_TYPES_ALLOWED)];
 #endif  // CONFIG_CROSS_CHROMA_TX
   int initialized;
 } FRAME_CONTEXT;
diff --git a/av1/common/enums.h b/av1/common/enums.h
index 9f33c0c..aef609b 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -404,14 +404,14 @@
 } UENUM1BYTE(TX_TYPE);
 
 #if CONFIG_CROSS_CHROMA_TX
-#define CCTX_NEG_ANGLES 1
-#define CCTX_POS_ANGLES 1
-// Always signal C1 coefficients for some cctx (i.e., both C1 and C2 nonzero
-// or C1 nonzero and C2 zero). This requires CCTX_NEG_ANGLES to be on.
-#define CCTX_C1_NONZERO 1
-// Drop C2 channel for some cctx_types. This macro requires CCTX_C1_NONZERO to
-// be on.
+// Number of contexts used to code cctx_type; derived from the above and left
+// cctx types by get_cctx_context().
+#define CCTX_CONTEXTS 3
+// Drop the C2 channel for some cctx types (see keep_chroma_c2()).
 #define CCTX_C2_DROPPED 0
+// When enabled, cctx_type is coded as an index into CCTX_TYPES_ALLOWED
+// candidates derived from the above and left cctx types.
+#define CCTX_ADAPT_REDUCED_SET 0
 // Configuration for the set of rotation angles
 // 0: { 45, 30, 60 }
 // 1: { 45, 22.5, 67.5 }
@@ -421,7 +421,6 @@
 #define CCTX_ANGLE_CONFIG 0
 enum {
   CCTX_NONE,  // No cross chroma transform
-#if CCTX_POS_ANGLES
 #if CCTX_ANGLE_CONFIG != 3
   CCTX_45,  // 45 degrees rotation (Haar transform)
 #endif
@@ -429,8 +428,6 @@
   CCTX_30,  // 30 degrees rotation
   CCTX_60,  // 60 degrees rotation
 #endif
-#endif  // CCTX_POS_ANGLES
-#if CCTX_NEG_ANGLES
 #if CCTX_ANGLE_CONFIG != 3
   CCTX_M45,  // -45 degrees rotation
 #endif
@@ -438,10 +435,18 @@
   CCTX_M30,  // -30 degrees rotation
   CCTX_M60,  // -60 degrees rotation
 #endif
-#endif  // CCTX_NEG_ANGLES
   CCTX_TYPES,
   CCTX_START = CCTX_NONE + 1,
 } UENUM1BYTE(CctxType);
+
+// Number of symbols in the cctx_type CDF: CCTX_NONE plus two neighbor-derived
+// types with CCTX_ADAPT_REDUCED_SET, otherwise every cctx type.
+#if CCTX_ADAPT_REDUCED_SET
+#define CCTX_TYPES_ALLOWED 3
+#else
+#define CCTX_TYPES_ALLOWED CCTX_TYPES
+#endif  // CCTX_ADAPT_REDUCED_SET
+
 #endif  // CONFIG_CROSS_CHROMA_TX
 
 enum {
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h
index 3491536..b614df3 100644
--- a/av1/common/pred_common.h
+++ b/av1/common/pred_common.h
@@ -357,6 +357,139 @@
 #endif  // CONFIG_SKIP_MODE_ENHANCEMENT
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+#if CCTX_ADAPT_REDUCED_SET
+// The reduced-set helpers below are laid out for the CCTX_ANGLE_CONFIG == 0
+// type order { NONE, 45, 30, 60, -45, -30, -60 }.
+#if CCTX_ANGLE_CONFIG != 0
+#error "CCTX_ADAPT_REDUCED_SET requires CCTX_ANGLE_CONFIG == 0"
+#endif
+
+// For each cctx type (indexed by CctxType), the same-sign nonzero cctx type
+// offered as an extra candidate when only one distinct nonzero neighbor type
+// is known. The CCTX_NONE entry is never read by the helpers below.
+static const CctxType closest_nonzero_cctx[CCTX_TYPES] = {
+  CCTX_30, CCTX_30, CCTX_45, CCTX_30, CCTX_M30, CCTX_M45, CCTX_M30
+};
+
+// Return the 3 allowed cctx types, as a bitmask of (1 << type), given the
+// above and left cctx types (-1 when unavailable). CCTX_NONE is always
+// allowed. Each neighbor type other than -1 / CCTX_NONE is added; remaining
+// slots are filled from closest_nonzero_cctx, or with CCTX_30 and CCTX_M30
+// when neither neighbor contributes.
+static INLINE uint8_t get_allowed_cctx_mask(int above, int left) {
+  if (above <= CCTX_NONE && left <= CCTX_NONE)
+    return (1 << CCTX_NONE) + (1 << CCTX_30) + (1 << CCTX_M30);
+  else if (above <= CCTX_NONE)
+    return (1 << CCTX_NONE) + (1 << left) + (1 << closest_nonzero_cctx[left]);
+  else if (left <= CCTX_NONE || above == left)
+    return (1 << CCTX_NONE) + (1 << above) + (1 << closest_nonzero_cctx[above]);
+  else
+    return (1 << CCTX_NONE) + (1 << above) + (1 << left);
+}
+
+// Same candidate set as get_allowed_cctx_mask(), written in coding order
+// into cctxarr[0..2] (cctxarr[0] is always CCTX_NONE).
+static INLINE void get_allowed_cctx_arr(const int above, const int left,
+                                        CctxType *cctxarr) {
+  cctxarr[0] = CCTX_NONE;
+  if (above <= CCTX_NONE && left <= CCTX_NONE) {
+    cctxarr[1] = CCTX_30;
+    cctxarr[2] = CCTX_M30;
+  } else if (above <= CCTX_NONE) {
+    cctxarr[1] = left;
+    cctxarr[2] = closest_nonzero_cctx[left];
+  } else if (left <= CCTX_NONE || above == left) {
+    cctxarr[1] = above;
+    cctxarr[2] = closest_nonzero_cctx[above];
+  } else {
+    cctxarr[1] = above;
+    cctxarr[2] = left;
+  }
+}
+
+// Map a coded index in [0, CCTX_TYPES_ALLOWED) to its cctx type.
+static INLINE CctxType cctx_idx_to_type(const int cctx_idx, const int above,
+                                        const int left) {
+  CctxType cctx_arr[CCTX_TYPES_ALLOWED] = { 0 };
+  get_allowed_cctx_arr(above, left, cctx_arr);
+  return cctx_arr[cctx_idx];
+}
+
+// Inverse of cctx_idx_to_type(); ctype must be one of the allowed types.
+static INLINE uint8_t cctx_type_to_idx(const CctxType ctype, const int above,
+                                       const int left) {
+  CctxType cctx_arr[CCTX_TYPES_ALLOWED] = { 0 };
+  get_allowed_cctx_arr(above, left, cctx_arr);
+  if (ctype == cctx_arr[0]) return 0;
+  if (ctype == cctx_arr[1]) return 1;
+  if (ctype == cctx_arr[2]) return 2;
+  assert(0);
+  return 0;
+}
+#endif  // CCTX_ADAPT_REDUCED_SET
+
+// Fetch the cctx types of the above and left neighbors of the chroma
+// transform block at (blk_row, blk_col). A neighbor inside the current block
+// is read from xd->cctx_type_map; otherwise it is read from the frame-level
+// map at the (sub-8x8 adjusted) block origin, or set to -1 when chroma is
+// unavailable on that side.
+static INLINE void get_above_and_left_cctx_type(
+    const AV1_COMMON *cm, const MACROBLOCKD *xd, int blk_row, int blk_col,
+    TX_SIZE tx_size, int *above_cctx, int *left_cctx) {
+  const int ss_x = xd->plane[AOM_PLANE_U].subsampling_x;
+  const int ss_y = xd->plane[AOM_PLANE_U].subsampling_y;
+  const int txh = tx_size_high_unit[tx_size];
+  const int txw = tx_size_wide_unit[tx_size];
+
+  const CommonModeInfoParams *const mi_params = &cm->mi_params;
+  const int mi_stride = mi_params->mi_stride;
+
+  // Offsets are needed for sub 8x8 blocks to reach the top left corner of the
+  // current block where the current cctx_type is applied
+  const int mi_row_offset = (xd->mi_row & 0x01) && (txh & 0x01) && ss_y;
+  const int mi_col_offset = (xd->mi_col & 0x01) && (txw & 0x01) && ss_x;
+  const int mi_grid_idx = get_mi_grid_idx(mi_params, xd->mi_row - mi_row_offset,
+                                          xd->mi_col - mi_col_offset);
+  CctxType *const cur_cctx_ptr = mi_params->cctx_type_map + mi_grid_idx;
+
+  // TODO(kslu) change this workaround for shifts
+  const int cctx_stride = xd->tx_type_map_stride;
+  const int br = (txw == (cctx_stride >> ss_x)) ? blk_row : (blk_row << ss_y);
+  const int bc = (txw == (cctx_stride >> ss_x)) ? blk_col : (blk_col << ss_x);
+  if (blk_row)
+    *above_cctx = (int)xd->cctx_type_map[(br - 1) * cctx_stride + bc];
+  else
+    *above_cctx = xd->chroma_up_available ? (int)cur_cctx_ptr[-mi_stride] : -1;
+
+  if (blk_col)
+    *left_cctx = (int)xd->cctx_type_map[br * cctx_stride + bc - 1];
+  else
+    *left_cctx = xd->chroma_left_available ? (int)cur_cctx_ptr[-1] : -1;
+
+  assert(*above_cctx >= -1 && *above_cctx < CCTX_TYPES);
+  assert(*left_cctx >= -1 && *left_cctx < CCTX_TYPES);
+}
+
+// Return the context (0..CCTX_CONTEXTS-1) for coding cctx_type. Each
+// neighbor is classified as 0 (unavailable, -1 or CCTX_NONE), 1 (types in
+// (CCTX_NONE, CCTX_60], the positive angles) or 2 (types above CCTX_60, the
+// negative angles). If one neighbor's class is 0 the other's class is used;
+// otherwise the shared class is used when they agree and 0 when they differ.
+static INLINE int get_cctx_context(const MACROBLOCKD *xd, const int above,
+                                   const int left) {
+  // NOTE(review): a neighbor inside the current block (blk_row/blk_col > 0)
+  // is still ignored when chroma_up_available / chroma_left_available is 0.
+  // Confirm this is intended.
+  int above_ctx =
+      xd->chroma_up_available ? ((above > CCTX_60) + (above > CCTX_NONE)) : 0;
+  int left_ctx =
+      xd->chroma_left_available ? ((left > CCTX_60) + (left > CCTX_NONE)) : 0;
+  if (above_ctx == 0 || left_ctx == 0) return AOMMAX(above_ctx, left_ctx);
+  return (above_ctx == left_ctx) ? above_ctx : 0;
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 int av1_get_pred_context_switchable_interp(const MACROBLOCKD *xd, int dir);
 
 // Get a list of palette base colors that are used in the above and left blocks,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index ba3490a..a1190be 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -251,14 +251,8 @@
 #if CONFIG_CROSS_CHROMA_TX && CCTX_INTRA
     eob_info *eob_data_u =
         dcb->eob_data[AOM_PLANE_U] + dcb->txb_offset[AOM_PLANE_U];
-#if CCTX_C1_NONZERO
-    if (eob_data->eob || (av1_get_cctx_type(xd, row, col) > CCTX_NONE &&
-                          plane == AOM_PLANE_V && eob_data_u->eob)) {
-#else
-    eob_info *eob_data_v =
-        dcb->eob_data[AOM_PLANE_V] + dcb->txb_offset[AOM_PLANE_V];
-    if (eob_data->eob || (plane && (eob_data_u->eob || eob_data_v->eob))) {
-#endif
+    if (eob_data->eob || (plane == AOM_PLANE_V && eob_data_u->eob &&
+                          av1_get_cctx_type(xd, row, col) > CCTX_NONE)) {
 #else
     if (eob_data->eob) {
 #endif  // CONFIG_CROSS_CHROMA_TX
@@ -1319,6 +1313,15 @@
                                       blk_row, blk_col, block, max_tx_size,
                                       &eobtotal);
                 block += stepr * stepc;
+#if CONFIG_CROSS_CHROMA_TX
+              } else if (plane == AOM_PLANE_U) {
+                // Fill cctx_type_map with CCTX_NONE for skip blocks so that
+                // later neighbors can derive their cctx contexts.
+                int row_offset, col_offset;
+                get_offsets_to_8x8(xd, max_tx_size, &row_offset, &col_offset);
+                update_cctx_array(xd, blk_row, blk_col, row_offset, col_offset,
+                                  max_tx_size, CCTX_NONE);
+#endif  // CONFIG_CROSS_CHROMA_TX
               }
             }
           }
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 5181fcd..d6b9e07 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -874,9 +874,13 @@
                         int blk_row, int blk_col, TX_SIZE tx_size,
                         aom_reader *r) {
   MB_MODE_INFO *mbmi = xd->mi[0];
-  CctxType *cctx_type =
-      &xd->cctx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
-  *cctx_type = CCTX_NONE;
+  // Reset the cctx type of this transform block to CCTX_NONE before any early
+  // return below. For sub-8x8 chroma blocks the offsets move the update to the
+  // parent 8x8 area, which is also where the coded type is written later.
+  int row_offset, col_offset;
+  get_offsets_to_8x8(xd, tx_size, &row_offset, &col_offset);
+  update_cctx_array(xd, blk_row, blk_col, row_offset, col_offset, tx_size,
+                    CCTX_NONE);
 
   // No need to read transform type if block is skipped.
   if (mbmi->skip_txfm[xd->tree_type == CHROMA_PART] ||
@@ -888,11 +892,26 @@
   if (qindex == 0) return;
 
   const int is_inter = is_inter_block(mbmi, xd->tree_type);
+  CctxType cctx_type = CCTX_NONE;
   if ((is_inter && CCTX_INTER) || (!is_inter && CCTX_INTRA)) {
     FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
     const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
-    *cctx_type = aom_read_symbol(r, ec_ctx->cctx_type_cdf[square_tx_size],
-                                 CCTX_TYPES, ACCT_STR);
+    int above_cctx, left_cctx;
+    get_above_and_left_cctx_type(cm, xd, blk_row, blk_col, tx_size, &above_cctx,
+                                 &left_cctx);
+    const int cctx_ctx = get_cctx_context(xd, above_cctx, left_cctx);
+#if CCTX_ADAPT_REDUCED_SET
+    const int cctx_idx =
+        aom_read_symbol(r, ec_ctx->cctx_type_cdf[square_tx_size][cctx_ctx],
+                        CCTX_TYPES_ALLOWED, ACCT_STR);
+    cctx_type = cctx_idx_to_type(cctx_idx, above_cctx, left_cctx);
+#else
+    cctx_type =
+        aom_read_symbol(r, ec_ctx->cctx_type_cdf[square_tx_size][cctx_ctx],
+                        CCTX_TYPES_ALLOWED, ACCT_STR);
+#endif  // CCTX_ADAPT_REDUCED_SET
+    update_cctx_array(xd, blk_row, blk_col, row_offset, col_offset, tx_size,
+                      cctx_type);
   }
 }
 #endif  // CONFIG_CROSS_CHROMA_TX
diff --git a/av1/decoder/decoder.c b/av1/decoder/decoder.c
index 0558a17..a77ba3e 100644
--- a/av1/decoder/decoder.c
+++ b/av1/decoder/decoder.c
@@ -103,6 +103,10 @@
       mi_params->mi_stride * calc_mi_size(mi_params->mi_rows);
   memset(mi_params->mi_grid_base, 0,
          mi_grid_size * sizeof(*mi_params->mi_grid_base));
+#if CONFIG_CROSS_CHROMA_TX
+  memset(mi_params->cctx_type_map, 0,
+         mi_grid_size * sizeof(*mi_params->cctx_type_map));  // 0 == CCTX_NONE
+#endif  // CONFIG_CROSS_CHROMA_TX
 }
 
 static void dec_free_mi(CommonModeInfoParams *mi_params) {
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 779408f..cc178c7 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -14,6 +14,7 @@
 
 #include "aom_ports/mem.h"
 #include "av1/common/idct.h"
+#include "av1/common/pred_common.h"
 #include "av1/common/scan.h"
 #include "av1/common/txb_common.h"
 #if CONFIG_FORWARDSKIP
@@ -192,23 +193,16 @@
 #endif  // CONFIG_CONTEXT_DERIVATION
 
 #if CONFIG_CROSS_CHROMA_TX
-#if CCTX_C1_NONZERO
   if (plane == AOM_PLANE_U) {
-    if (!all_zero)
+    if (!all_zero) {
       av1_read_cctx_type(cm, xd, blk_row, blk_col, tx_size, r);
-    else
-      xd->cctx_type_map[blk_row * xd->tx_type_map_stride + blk_col] = CCTX_NONE;
+    } else {  // no cctx type is coded; record CCTX_NONE for contexts
+      int row_offset, col_offset;
+      get_offsets_to_8x8(xd, tx_size, &row_offset, &col_offset);
+      update_cctx_array(xd, blk_row, blk_col, row_offset, col_offset, tx_size,
+                        CCTX_NONE);
+    }
   }
-#else
-  if (plane == AOM_PLANE_V) {
-    // cctx_type will be read either eob_v > 0 or eob_u > 0
-    eob_info *eob_data_u =
-        dcb->eob_data[AOM_PLANE_U] + dcb->txb_offset[AOM_PLANE_U];
-    const uint16_t eob_u = eob_data_u->eob;
-    if (!all_zero || eob_u > 0)
-      av1_read_cctx_type(cm, xd, blk_row, blk_col, tx_size, r);
-  }
-#endif  // CCTX_C1_NONZERO
 #endif  // CONFIG_CROSS_CHROMA_TX
 
   if (all_zero) {
@@ -363,19 +357,16 @@
 #endif  // CONFIG_CONTEXT_DERIVATION
 
 #if CONFIG_CROSS_CHROMA_TX
-#if CCTX_C1_NONZERO
-  if (plane == AOM_PLANE_U && !all_zero) {
-    av1_read_cctx_type(cm, xd, blk_row, blk_col, tx_size, r);
-  }
-#else
-  if (plane == AOM_PLANE_V) {
-    eob_info *eob_data_u =
-        dcb->eob_data[AOM_PLANE_U] + dcb->txb_offset[AOM_PLANE_U];
-    uint16_t eob_u = eob_data_u->eob;
-    if (!all_zero || eob_u > 0)
+  if (plane == AOM_PLANE_U) {
+    if (!all_zero) {
       av1_read_cctx_type(cm, xd, blk_row, blk_col, tx_size, r);
+    } else {  // no cctx type is coded; record CCTX_NONE for contexts
+      int row_offset, col_offset;
+      get_offsets_to_8x8(xd, tx_size, &row_offset, &col_offset);
+      update_cctx_array(xd, blk_row, blk_col, row_offset, col_offset, tx_size,
+                        CCTX_NONE);
+    }
   }
-#endif  // CCTX_C1_NONZERO
 #endif  // CONFIG_CROSS_CHROMA_TX
 #endif  // CONFIG_FORWARDSKIP
   eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
@@ -722,9 +713,6 @@
         for (int idy = 0; idy < txh; idy += tx_unit) {
           for (int idx = 0; idx < txw; idx += tx_unit) {
             xd->tx_type_map[(row + idy) * stride + col + idx] = tx_type;
-#if CONFIG_CROSS_CHROMA_TX
-            xd->cctx_type_map[(row + idy) * stride + col + idx] = CCTX_NONE;
-#endif  // CONFIG_CROSS_CHROMA_TX
           }
         }
       }
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index c062195..6653955 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1160,7 +1160,8 @@
 
 #if CONFIG_CROSS_CHROMA_TX
 void av1_write_cctx_type(const AV1_COMMON *const cm, const MACROBLOCKD *xd,
-                         CctxType cctx_type, TX_SIZE tx_size, aom_writer *w) {
+                         int blk_row, int blk_col, CctxType cctx_type,
+                         TX_SIZE tx_size, aom_writer *w) {
   MB_MODE_INFO *mbmi = xd->mi[0];
   assert(xd->is_chroma_ref);
   const int is_inter = is_inter_block(mbmi, xd->tree_type);
@@ -1171,8 +1172,19 @@
       !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
     FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
     const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
-    aom_write_symbol(w, cctx_type, ec_ctx->cctx_type_cdf[square_tx_size],
-                     CCTX_TYPES);
+    int above_cctx, left_cctx;
+    get_above_and_left_cctx_type(cm, xd, blk_row, blk_col, tx_size, &above_cctx,
+                                 &left_cctx);
+    const int cctx_ctx = get_cctx_context(xd, above_cctx, left_cctx);
+#if CCTX_ADAPT_REDUCED_SET
+    aom_write_symbol(w, cctx_type_to_idx(cctx_type, above_cctx, left_cctx),
+                     ec_ctx->cctx_type_cdf[square_tx_size][cctx_ctx],
+                     CCTX_TYPES_ALLOWED);
+#else
+    aom_write_symbol(w, cctx_type,
+                     ec_ctx->cctx_type_cdf[square_tx_size][cctx_ctx],
+                     CCTX_TYPES_ALLOWED);
+#endif  // CCTX_ADAPT_REDUCED_SET
   }
 }
 #endif  // CONFIG_CROSS_CHROMA_TX
@@ -2186,6 +2198,23 @@
   set_mi_row_col(xd, tile, mi_row, bh, mi_col, bw, mi_params->mi_rows,
                  mi_params->mi_cols);
 
+#if CONFIG_CROSS_CHROMA_TX
+  // For skip blocks, reset the corresponding area in cctx_type_map to
+  // CCTX_NONE, which will be used as contexts for later blocks. No need to use
+  // av1_get_adjusted_tx_size because uv_txsize is intended to cover the entire
+  // prediction block area
+  if (mbmi->skip_txfm[xd->tree_type == CHROMA_PART] &&
+      xd->tree_type != LUMA_PART && xd->is_chroma_ref) {
+    struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
+    const BLOCK_SIZE uv_bsize =
+        get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
+    const TX_SIZE uv_txsize = max_txsize_rect_lookup[uv_bsize];
+    int row_offset, col_offset;
+    get_offsets_to_8x8(xd, uv_txsize, &row_offset, &col_offset);
+    update_cctx_array(xd, 0, 0, row_offset, col_offset, uv_txsize, CCTX_NONE);
+  }
+#endif  // CONFIG_CROSS_CHROMA_TX
+
   xd->above_txfm_context = cm->above_contexts.txfm[tile->tile_row] + mi_col;
   xd->left_txfm_context =
       xd->left_txfm_context_buffer + (mi_row & MAX_MIB_MASK);
diff --git a/av1/encoder/bitstream.h b/av1/encoder/bitstream.h
index 8a7cd65..e788253 100644
--- a/av1/encoder/bitstream.h
+++ b/av1/encoder/bitstream.h
@@ -55,7 +55,8 @@
 
 #if CONFIG_CROSS_CHROMA_TX
 void av1_write_cctx_type(const AV1_COMMON *const cm, const MACROBLOCKD *xd,
-                         CctxType cctx_type, TX_SIZE tx_size, aom_writer *w);
+                         int blk_row, int blk_col, CctxType cctx_type,
+                         TX_SIZE tx_size, aom_writer *w);
 #endif  // CONFIG_CROSS_CHROMA_TX
 
 #ifdef __cplusplus
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 01f70d5..3abd50e 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -279,7 +279,7 @@
   TX_TYPE tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
 #if CONFIG_CROSS_CHROMA_TX
   //! Map showing the cctx types for each block.
-  TX_TYPE cctx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+  CctxType cctx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
 #endif  // CONFIG_CROSS_CHROMA_TX
   //! Rd_stats for the whole partition block.
   RD_STATS rd_stats;
@@ -855,7 +855,7 @@
 #endif  // CONFIG_DDT_INTER
 #if CONFIG_CROSS_CHROMA_TX
   //! cctx_type_cost
-  int cctx_type_cost[EXT_TX_SIZES][CCTX_TYPES];
+  int cctx_type_cost[EXT_TX_SIZES][CCTX_CONTEXTS][CCTX_TYPES_ALLOWED];
 #endif  // CONFIG_CROSS_CHROMA_TX
   /**@}*/
 
diff --git a/av1/encoder/encodeframe_utils.c b/av1/encoder/encodeframe_utils.c
index dcb753e..91a7073 100644
--- a/av1/encoder/encodeframe_utils.c
+++ b/av1/encoder/encodeframe_utils.c
@@ -247,18 +247,32 @@
   }
 
 #if CONFIG_CROSS_CHROMA_TX
-  if (xd->tree_type != LUMA_PART) {
+  if (xd->tree_type != LUMA_PART && xd->is_chroma_ref) {
     xd->cctx_type_map = ctx->cctx_type_map;
     xd->tx_type_map_stride = mi_size_wide[bsize];
+    // If this block is sub 8x8 in luma, derive the parent >= 8x8 block area,
+    // then update its corresponding chroma area in cctx_type_map to the
+    // current cctx type
+    const int ss_x = pd[AOM_PLANE_U].subsampling_x;
+    const int ss_y = pd[AOM_PLANE_U].subsampling_y;
+    const int mi_row_offset = (mi_row & 0x01) && (bh & 0x01) && ss_y;
+    const int mi_col_offset = (mi_col & 0x01) && (bw & 0x01) && ss_x;
+    const int grid_idx = get_mi_grid_idx(mi_params, mi_row - mi_row_offset,
+                                         mi_col - mi_col_offset);
+    CctxType *const cctx_type_map = mi_params->cctx_type_map + grid_idx;
+    const int mi_stride = mi_params->mi_stride;
+    const int is_inter = is_inter_block(mi_addr, xd->tree_type);
+    const int allow_cctx =
+        (is_inter && CCTX_INTER) || (!is_inter && CCTX_INTRA);
+    // Set cctx_type to CCTX_NONE when not allowed or for skip blocks
+    CctxType cur_cctx_type = (txfm_info->skip_txfm || !allow_cctx)
+                                 ? CCTX_NONE
+                                 : xd->cctx_type_map[0];
+    for (int blk_row = 0; blk_row < (mi_row_offset ? 2 : bh); ++blk_row) {
+      memset(&cctx_type_map[blk_row * mi_stride], cur_cctx_type,
+             (mi_col_offset ? 2 : bw) * sizeof(cctx_type_map[0]));
+    }
     if (!dry_run) {
-      const int grid_idx = get_mi_grid_idx(mi_params, mi_row, mi_col);
-      CctxType *const cctx_type_map = mi_params->cctx_type_map + grid_idx;
-      const int mi_stride = mi_params->mi_stride;
-      for (int blk_row = 0; blk_row < bh; ++blk_row) {
-        av1_copy_array(cctx_type_map + blk_row * mi_stride,
-                       xd->cctx_type_map + blk_row * xd->tx_type_map_stride,
-                       bw);
-      }
       xd->cctx_type_map = cctx_type_map;
       xd->tx_type_map_stride = mi_stride;
     }
@@ -1359,8 +1373,8 @@
                  CDF_SIZE(STX_TYPES));
 #endif
 #if CONFIG_CROSS_CHROMA_TX
-  AVG_CDF_STRIDE(ctx_left->cctx_type_cdf, ctx_tr->cctx_type_cdf, CCTX_TYPES,
-                 CDF_SIZE(CCTX_TYPES));
+  AVERAGE_CDF(ctx_left->cctx_type_cdf, ctx_tr->cctx_type_cdf,
+              CCTX_TYPES_ALLOWED);
 #endif  // CONFIG_CROSS_CHROMA_TX
 }
 
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index a2453b5..f3e4d49 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -78,7 +78,7 @@
 int av1_optimize_b(const struct AV1_COMP *cpi, MACROBLOCK *x, int plane,
                    int block, TX_SIZE tx_size, TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-                   CctxType cctx_type,
+                   CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
                    const TXB_CTX *const txb_ctx, int *rate_cost) {
   MACROBLOCKD *const xd = &x->e_mbd;
@@ -95,14 +95,15 @@
 #endif  // CONFIG_CONTEXT_DERIVATION
     );
 #if CONFIG_CROSS_CHROMA_TX
-    *rate_cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+    *rate_cost += get_cctx_type_cost(&cpi->common, x, xd, plane, tx_size,
+                                     blk_row, blk_col, block, cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
     return eob;
   }
 
   return av1_optimize_txb_new(cpi, x, plane, block, tx_size, tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-                              cctx_type,
+                              cctx_type, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
                               txb_ctx, rate_cost, cpi->oxcf.algo_cfg.sharpness);
 }
@@ -560,11 +561,26 @@
                                     tx_size, cm->features.reduced_tx_set_used);
 #if CONFIG_CROSS_CHROMA_TX
   CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+#if CCTX_ADAPT_REDUCED_SET
+  if (plane) {
+    // TODO(kslu) change this workaround
+    // Since contexts can be changed during the dry run tx search, check if the
+    // cctx type is valid here. If not, just use CCTX_NONE.
+    int above_cctx, left_cctx;
+    get_above_and_left_cctx_type(cm, xd, blk_row, blk_col, tx_size, &above_cctx,
+                                 &left_cctx);
+    uint8_t allowed_cctx_mask = get_allowed_cctx_mask(above_cctx, left_cctx);
+    if (!(allowed_cctx_mask & (1 << cctx_type))) {
+      cctx_type = CCTX_NONE;
+      update_cctx_array(xd, blk_row, blk_col, 0, 0, tx_size, CCTX_NONE);
+    }
+  }
+#endif
 #endif  // CONFIG_CROSS_CHROMA_TX
 
   if (!is_blk_skip(x->txfm_search_info.blk_skip, plane,
                    blk_row * bw + blk_col) &&
-#if CONFIG_CROSS_CHROMA_TX && CCTX_INTER && CCTX_C1_NONZERO
+#if CONFIG_CROSS_CHROMA_TX && CCTX_INTER
 #if CCTX_C2_DROPPED
       (plane < AOM_PLANE_V ||
        ((cctx_type == CCTX_NONE || x->plane[AOM_PLANE_U].eobs[block]) &&
@@ -573,7 +589,7 @@
       (plane < AOM_PLANE_V || cctx_type == CCTX_NONE ||
        x->plane[AOM_PLANE_U].eobs[block]) &&
 #endif
-#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTER && CCTX_C1_NONZERO
+#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_INTER
 #if CONFIG_SKIP_MODE_ENHANCEMENT
       !(mbmi->skip_mode == 1)) {
 #else
@@ -635,7 +651,7 @@
       );
       av1_optimize_b(args->cpi, x, plane, block, tx_size, tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-                     cctx_type,
+                     cctx_type, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
                      &txb_ctx, &dummy_rate_cost);
     }
@@ -647,13 +663,12 @@
       av1_dropout_qcoeff(x, plane, block, tx_size, tx_type,
                          cm->quant_params.base_qindex);
     }
-#if CONFIG_CROSS_CHROMA_TX && CCTX_C1_NONZERO
+#if CONFIG_CROSS_CHROMA_TX
     // Since eob can be updated here, make sure cctx_type is always CCTX_NONE
     // when eob of U is 0.
-    // TODO(kslu) why cctx_type can be > CCTX_NONE when eob_u is 0?
     if (plane == AOM_PLANE_U && p->eobs[block] == 0)
-      update_cctx_array(xd, blk_row, blk_col, tx_size, CCTX_NONE);
-#endif  // CONFIG_CROSS_CHROMA_TX && CCTX_C1_NONZERO
+      update_cctx_array(xd, blk_row, blk_col, 0, 0, tx_size, CCTX_NONE);
+#endif  // CONFIG_CROSS_CHROMA_TX
   } else {
 #if CONFIG_CROSS_CHROMA_TX && CCTX_C2_DROPPED
     // Reset coeffs and dqcoeffs
@@ -1146,7 +1161,7 @@
       );
       av1_optimize_b(args->cpi, x, plane, block, tx_size, tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-                     CCTX_NONE,
+                     CCTX_NONE, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
                      &txb_ctx, &dummy_rate_cost);
     }
@@ -1246,6 +1261,19 @@
   TX_TYPE tx_type = av1_get_tx_type(xd, PLANE_TYPE_UV, blk_row, blk_col,
                                     tx_size, cm->features.reduced_tx_set_used);
   CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+#if CCTX_ADAPT_REDUCED_SET
+  // TODO(kslu) change this workaround
+  // Since contexts can be changed during the dry run tx search, check if the
+  // cctx type is valid here. If not, just use CCTX_NONE.
+  int above_cctx, left_cctx;
+  get_above_and_left_cctx_type(cm, xd, blk_row, blk_col, tx_size, &above_cctx,
+                               &left_cctx);
+  uint8_t allowed_cctx_mask = get_allowed_cctx_mask(above_cctx, left_cctx);
+  if (!(allowed_cctx_mask & (1 << cctx_type))) {
+    cctx_type = CCTX_NONE;
+    update_cctx_array(xd, blk_row, blk_col, 0, 0, tx_size, CCTX_NONE);
+  }
+#endif
 
   av1_subtract_txb(x, AOM_PLANE_U, plane_bsize, blk_col, blk_row, tx_size);
   av1_subtract_txb(x, AOM_PLANE_V, plane_bsize, blk_col, blk_row, tx_size);
@@ -1283,22 +1311,20 @@
                             INTRA_BLOCK_OPT_TYPE == TRELLIS_DROPOUT_OPT));
 
   for (int plane = AOM_PLANE_U; plane <= AOM_PLANE_V; plane++) {
-#if CCTX_C1_NONZERO
+    // Since eob can be updated here, make sure cctx_type is always CCTX_NONE
+    // when eob of U is 0.
+    if (plane == AOM_PLANE_V && *eob_u == 0)
+      update_cctx_array(xd, blk_row, blk_col, 0, 0, tx_size, CCTX_NONE);
 #if CCTX_C2_DROPPED
     if (plane == AOM_PLANE_V && (!keep_chroma_c2(cctx_type) ||
                                  (*eob_u == 0 && cctx_type > CCTX_NONE))) {
 #else
     if (plane == AOM_PLANE_V && *eob_u == 0 && cctx_type > CCTX_NONE) {
 #endif
-      // Since eob can be updated here, make sure cctx_type is always CCTX_NONE
-      // when eob of U is 0.
-      if (*eob_u == 0 && cctx_type > CCTX_NONE)
-        update_cctx_array(xd, blk_row, blk_col, tx_size, CCTX_NONE);
       av1_quantize_skip(av1_get_max_eob(tx_size),
                         p_v->coeff + BLOCK_OFFSET(block), dqcoeff_v, eob_v);
       break;
     }
-#endif
     av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                       &quant_param);
     av1_xform_quant(
@@ -1320,7 +1346,7 @@
 #endif  // CONFIG_FORWARDSKIP
       );
       av1_optimize_b(args->cpi, x, plane, block, tx_size, tx_type, cctx_type,
-                     &txb_ctx, &dummy_rate_cost);
+                     blk_row, blk_col, &txb_ctx, &dummy_rate_cost);
     }
     if (do_dropout) {
       av1_dropout_qcoeff(x, plane, block, tx_size, tx_type,
diff --git a/av1/encoder/encodemb.h b/av1/encoder/encodemb.h
index ade30ee..68dd9cc 100644
--- a/av1/encoder/encodemb.h
+++ b/av1/encoder/encodemb.h
@@ -124,7 +124,7 @@
 int av1_optimize_b(const struct AV1_COMP *cpi, MACROBLOCK *mb, int plane,
                    int block, TX_SIZE tx_size, TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-                   CctxType cctx_type,
+                   CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
                    const TXB_CTX *const txb_ctx, int *rate_cost);
 
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 6d8cd79..0e1f301 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -1298,7 +1298,7 @@
                            [TX_TYPES];
 #endif  // CONFIG_DDT_INTER
 #if CONFIG_CROSS_CHROMA_TX
-  unsigned int cctx_type[EXT_TX_SIZES][CCTX_TYPES];
+  unsigned int cctx_type[EXT_TX_SIZES][CCTX_CONTEXTS][CCTX_TYPES_ALLOWED];
 #endif  // CONFIG_CROSS_CHROMA_TX
   unsigned int filter_intra_mode[FILTER_INTRA_MODES];
   unsigned int filter_intra[BLOCK_SIZES_ALL][2];
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 4ca8a19..1157ad1 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -457,23 +457,11 @@
 #endif  // CONFIG_CONTEXT_DERIVATION
 
 #if CONFIG_CROSS_CHROMA_TX
-#if CCTX_C1_NONZERO
-  if (plane == AOM_PLANE_U && eob > 0) {
+  if (plane == AOM_PLANE_U) {
     CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
-    av1_write_cctx_type(cm, xd, cctx_type, tx_size, w);
+    if (eob > 0)
+      av1_write_cctx_type(cm, xd, blk_row, blk_col, cctx_type, tx_size, w);
   }
-#else
-  // tx_type is signaled with Y plane if eob > 0. cctx_type is signaled with V
-  // plane if either of eob_u and eob_v is > 0.
-  if (plane == AOM_PLANE_V) {
-    const uint16_t *eob_txb_u = cb_coef_buff->eobs[AOM_PLANE_U] + txb_offset;
-    const uint16_t eob_u = eob_txb_u[block];
-    if (eob > 0 || eob_u > 0) {
-      const CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
-      av1_write_cctx_type(cm, xd, cctx_type, tx_size, w);
-    }
-  }
-#endif  // CCTX_C1_NONZERO
 #endif  // CONFIG_CROSS_CHROMA_TX
 
   if (eob == 0) return 0;
@@ -600,22 +588,10 @@
 #endif  // CONFIG_CONTEXT_DERIVATION
 
 #if CONFIG_CROSS_CHROMA_TX
-#if CCTX_C1_NONZERO
   if (plane == AOM_PLANE_U && eob > 0) {
     CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
-    av1_write_cctx_type(cm, xd, cctx_type, tx_size, w);
+    av1_write_cctx_type(cm, xd, blk_row, blk_col, cctx_type, tx_size, w);
   }
-#else
-  // CCTX type is transmitted with V plane
-  if (plane == AOM_PLANE_V) {
-    const uint16_t *eob_txb_u = cb_coef_buff->eobs[AOM_PLANE_U] + txb_offset;
-    const uint16_t eob_u = eob_txb_u[block];
-    if (eob > 0 || eob_u > 0) {
-      const CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
-      av1_write_cctx_type(cm, xd, cctx_type, tx_size, w);
-    }
-  }
-#endif  // CCTX_C1_NONZERO
 #endif  // CONFIG_CROSS_CHROMA_TX
 #endif  // CONFIG_FORWARDSKIP
 
@@ -879,22 +855,27 @@
 }
 
 #if CONFIG_CROSS_CHROMA_TX
-int get_cctx_type_cost(const MACROBLOCK *x, const MACROBLOCKD *xd, int plane,
-                       TX_SIZE tx_size, int block, CctxType cctx_type) {
+int get_cctx_type_cost(const AV1_COMMON *cm, const MACROBLOCK *x,
+                       const MACROBLOCKD *xd, int plane, TX_SIZE tx_size,
+                       int blk_row, int blk_col, int block,
+                       CctxType cctx_type) {
   const int is_inter = is_inter_block(xd->mi[0], xd->tree_type);
-  const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
-#if CCTX_C1_NONZERO
-  (void)block;
   if (plane == AOM_PLANE_U && x->plane[plane].eobs[block] &&
+      ((is_inter && CCTX_INTER) || (!is_inter && CCTX_INTRA))) {
+    const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
+    int above_cctx, left_cctx;
+    get_above_and_left_cctx_type(cm, xd, blk_row, blk_col, tx_size, &above_cctx,
+                                 &left_cctx);
+    const int cctx_ctx = get_cctx_context(xd, above_cctx, left_cctx);
+#if CCTX_ADAPT_REDUCED_SET
+    const int cctx_idx = cctx_type_to_idx(cctx_type, above_cctx, left_cctx);
+    return x->mode_costs.cctx_type_cost[square_tx_size][cctx_ctx][cctx_idx];
 #else
-  if (plane == AOM_PLANE_V &&
-      (x->plane[AOM_PLANE_U].eobs[block] ||
-       x->plane[AOM_PLANE_V].eobs[block]) &&
-#endif  // CCTX_C1_NONZERO
-      ((is_inter && CCTX_INTER) || (!is_inter && CCTX_INTRA)))
-    return x->mode_costs.cctx_type_cost[square_tx_size][cctx_type];
-  else
+    return x->mode_costs.cctx_type_cost[square_tx_size][cctx_ctx][cctx_type];
+#endif
+  } else {
     return 0;
+  }
 }
 #endif  // CONFIG_CROSS_CHROMA_TX
 
@@ -1024,13 +1005,16 @@
 
 #if CONFIG_FORWARDSKIP
 static AOM_FORCE_INLINE int warehouse_efficients_txb_skip(
+#if CONFIG_CROSS_CHROMA_TX
+    const AV1_COMMON *cm,
+#endif  // CONFIG_CROSS_CHROMA_TX
     const MACROBLOCK *x, const int plane, const int block,
     const TX_SIZE tx_size, const TXB_CTX *const txb_ctx,
     const struct macroblock_plane *p, const int eob,
     const LV_MAP_COEFF_COST *const coeff_costs, const MACROBLOCKD *const xd,
     const TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-    const CctxType cctx_type,
+    const CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
     int reduced_tx_set_used) {
   const tran_low_t *const qcoeff = p->qcoeff + BLOCK_OFFSET(block);
@@ -1053,7 +1037,8 @@
 #endif  // CONFIG_IST
   );
 #if CONFIG_CROSS_CHROMA_TX
-  cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+  cost += get_cctx_type_cost(cm, x, xd, plane, tx_size, blk_row, blk_col, block,
+                             cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
   DECLARE_ALIGNED(16, int8_t, coeff_contexts[MAX_TX_SQUARE]);
   av1_get_nz_map_contexts_skip(levels, scan, eob, tx_size, coeff_contexts);
@@ -1088,13 +1073,16 @@
 #endif  // CONFIG_FORWARDSKIP
 
 static AOM_FORCE_INLINE int warehouse_efficients_txb(
+#if CONFIG_CROSS_CHROMA_TX
+    const AV1_COMMON *cm,
+#endif  // CONFIG_CROSS_CHROMA_TX
     const MACROBLOCK *x, const int plane, const int block,
     const TX_SIZE tx_size, const TXB_CTX *const txb_ctx,
     const struct macroblock_plane *p, const int eob,
     const PLANE_TYPE plane_type, const LV_MAP_COEFF_COST *const coeff_costs,
     const MACROBLOCKD *const xd, const TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-    const CctxType cctx_type,
+    const CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
     const TX_CLASS tx_class, int reduced_tx_set_used) {
   const tran_low_t *const qcoeff = p->qcoeff + BLOCK_OFFSET(block);
@@ -1138,7 +1126,8 @@
 #endif
   );
 #if CONFIG_CROSS_CHROMA_TX
-  cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+  cost += get_cctx_type_cost(cm, x, xd, plane, tx_size, blk_row, blk_col, block,
+                             cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
 
   cost += get_eob_cost(eob, eob_costs, coeff_costs, tx_class);
@@ -1249,15 +1238,15 @@
 }
 
 static AOM_FORCE_INLINE int warehouse_efficients_txb_laplacian(
-#if CONFIG_FORWARDSKIP
+#if CONFIG_FORWARDSKIP || CONFIG_CROSS_CHROMA_TX
     const AV1_COMMON *cm,
-#endif  // CONFIG_FORWARDSKIP
+#endif  // CONFIG_FORWARDSKIP || CONFIG_CROSS_CHROMA_TX
     const MACROBLOCK *x, const int plane, const int block,
     const TX_SIZE tx_size, const TXB_CTX *const txb_ctx, const int eob,
     const PLANE_TYPE plane_type, const LV_MAP_COEFF_COST *const coeff_costs,
     const MACROBLOCKD *const xd, const TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-    const CctxType cctx_type,
+    const CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
     const TX_CLASS tx_class, int reduced_tx_set_used) {
 #if CONFIG_CONTEXT_DERIVATION
@@ -1291,7 +1280,8 @@
 #endif  // CONFIG_IST
   );
 #if CONFIG_CROSS_CHROMA_TX
-  cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+  cost += get_cctx_type_cost(cm, x, xd, plane, tx_size, blk_row, blk_col, block,
+                             cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
 
 #if !CONFIG_FORWARDSKIP
@@ -1387,7 +1377,7 @@
 }
 #endif  // CONFIG_FORWARDSKIP
 
-#if CONFIG_FORWARDSKIP
+#if CONFIG_FORWARDSKIP || CONFIG_CROSS_CHROMA_TX
 int av1_cost_coeffs_txb(const AV1_COMMON *cm, const MACROBLOCK *x,
                         const int plane, const int block,
 #else
@@ -1395,7 +1385,7 @@
-#endif  // CONFIG_FORWARDSKIP
+#endif  // CONFIG_FORWARDSKIP || CONFIG_CROSS_CHROMA_TX
                         const TX_SIZE tx_size, const TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-                        const CctxType cctx_type,
+                        const CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
                         const TXB_CTX *const txb_ctx, int reduced_tx_set_used) {
   const struct macroblock_plane *p = &x->plane[plane];
@@ -1423,7 +1413,8 @@
     skip_cost += coeff_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][1];
 #endif  // CONFIG_CONTEXT_DERIVATION
 #if CONFIG_CROSS_CHROMA_TX
-    skip_cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+    skip_cost += get_cctx_type_cost(cm, x, xd, plane, tx_size, blk_row, blk_col,
+                                    block, cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
     return skip_cost;
   }
@@ -1443,38 +1434,49 @@
        tx_type == IDTX && plane == PLANE_TYPE_Y) ||
 #endif  // CONFIG_IST
       use_inter_fsc(cm, plane, tx_type, is_inter_block(mbmi, xd->tree_type))) {
-    return warehouse_efficients_txb_skip(x, plane, block, tx_size, txb_ctx, p,
-                                         eob, coeff_costs, xd, tx_type,
+    return warehouse_efficients_txb_skip(
 #if CONFIG_CROSS_CHROMA_TX
-                                         cctx_type,
+        cm,
 #endif  // CONFIG_CROSS_CHROMA_TX
-                                         reduced_tx_set_used);
+        x, plane, block, tx_size, txb_ctx, p, eob, coeff_costs, xd, tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+        cctx_type, blk_row, blk_col,
+#endif  // CONFIG_CROSS_CHROMA_TX
+        reduced_tx_set_used);
   } else {
-    return warehouse_efficients_txb(x, plane, block, tx_size, txb_ctx, p, eob,
-                                    plane_type, coeff_costs, xd, tx_type,
+    return warehouse_efficients_txb(
 #if CONFIG_CROSS_CHROMA_TX
-                                    cctx_type,
+        cm,
 #endif  // CONFIG_CROSS_CHROMA_TX
-                                    tx_class, reduced_tx_set_used);
+        x, plane, block, tx_size, txb_ctx, p, eob, plane_type, coeff_costs, xd,
+        tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+        cctx_type, blk_row, blk_col,
+#endif  // CONFIG_CROSS_CHROMA_TX
+        tx_class, reduced_tx_set_used);
   }
 #else
-  return warehouse_efficients_txb(x, plane, block, tx_size, txb_ctx, p, eob,
-                                  plane_type, coeff_costs, xd, tx_type,
+  return warehouse_efficients_txb(
 #if CONFIG_CROSS_CHROMA_TX
-                                  cctx_type,
+      cm,
 #endif  // CONFIG_CROSS_CHROMA_TX
-                                  tx_class, reduced_tx_set_used);
+      x, plane, block, tx_size, txb_ctx, p, eob, plane_type, coeff_costs, xd,
+      tx_type,
+#if CONFIG_CROSS_CHROMA_TX
+      cctx_type, blk_row, blk_col,
+#endif  // CONFIG_CROSS_CHROMA_TX
+      tx_class, reduced_tx_set_used);
 #endif  // CONFIG_FORWARDSKIP
 }
 
 int av1_cost_coeffs_txb_laplacian(
-#if CONFIG_FORWARDSKIP
+#if CONFIG_FORWARDSKIP || CONFIG_CROSS_CHROMA_TX
     const AV1_COMMON *cm,
 #endif  // CONFIG_FORWARDSKIP
     const MACROBLOCK *x, const int plane, const int block,
     const TX_SIZE tx_size, const TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-    const CctxType cctx_type,
+    const CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
     const TXB_CTX *const txb_ctx, const int reduced_tx_set_used,
     const int adjust_eob) {
@@ -1512,7 +1514,8 @@
     skip_cost += coeff_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][1];
 #endif  // CONFIG_CONTEXT_DERIVATION
 #if CONFIG_CROSS_CHROMA_TX
-    skip_cost += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+    skip_cost += get_cctx_type_cost(cm, x, xd, plane, tx_size, blk_row, blk_col,
+                                    block, cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
     return skip_cost;
   }
@@ -1524,13 +1527,13 @@
 #endif
 
   return warehouse_efficients_txb_laplacian(
-#if CONFIG_FORWARDSKIP
+#if CONFIG_FORWARDSKIP || CONFIG_CROSS_CHROMA_TX
       cm,
-#endif  // CONFIG_FORWARDSKIP
+#endif  // CONFIG_FORWARDSKIP || CONFIG_CROSS_CHROMA_TX
       x, plane, block, tx_size, txb_ctx, eob, plane_type, coeff_costs, xd,
       tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-      cctx_type,
+      cctx_type, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
       tx_class, reduced_tx_set_used);
 }
@@ -1951,7 +1954,7 @@
 int av1_optimize_txb_new(const struct AV1_COMP *cpi, MACROBLOCK *x, int plane,
                          int block, TX_SIZE tx_size, TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-                         CctxType cctx_type,
+                         CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
                          const TXB_CTX *const txb_ctx, int *rate_cost,
                          int sharpness) {
@@ -2164,7 +2167,8 @@
       av1_get_txb_entropy_context(qcoeff, scan_order, p->eobs[block]);
 
 #if CONFIG_CROSS_CHROMA_TX
-  accu_rate += get_cctx_type_cost(x, xd, plane, tx_size, block, cctx_type);
+  accu_rate += get_cctx_type_cost(cm, x, xd, plane, tx_size, blk_row, blk_col,
+                                  block, cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
 
   *rate_cost = accu_rate;
@@ -2205,11 +2209,22 @@
       !mbmi->skip_txfm[xd->tree_type == CHROMA_PART] &&
       !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
     const CctxType cctx_type = av1_get_cctx_type(xd, blk_row, blk_col);
+    int above_cctx, left_cctx;
+    get_above_and_left_cctx_type(cm, xd, blk_row, blk_col, tx_size, &above_cctx,
+                                 &left_cctx);
+    const int cctx_ctx = get_cctx_context(xd, above_cctx, left_cctx);
+#if CCTX_ADAPT_REDUCED_SET
+    // Map cctx_type to the symbol actually coded; under the reduced set it
+    // depends on the above/left neighbor cctx types.
+    const int cctx_symbol = cctx_type_to_idx(cctx_type, above_cctx, left_cctx);
+#else
+    const int cctx_symbol = cctx_type;
+#endif  // CCTX_ADAPT_REDUCED_SET
     if (allow_update_cdf)
-      update_cdf(fc->cctx_type_cdf[txsize_sqr_map[tx_size]], cctx_type,
-                 CCTX_TYPES);
+      update_cdf(fc->cctx_type_cdf[txsize_sqr_map[tx_size]][cctx_ctx],
+                 cctx_symbol, CCTX_TYPES_ALLOWED);
 #if CONFIG_ENTROPY_STATS
-    ++counts->cctx_type[txsize_sqr_map[tx_size]][cctx_type];
+    ++counts->cctx_type[txsize_sqr_map[tx_size]][cctx_ctx][cctx_symbol];
 #endif  // CONFIG_ENTROPY_STATS
   }
 }
@@ -2659,17 +2678,9 @@
     eob_txb[block] = eob;
 
 #if CONFIG_CROSS_CHROMA_TX
-#if CCTX_C1_NONZERO
     if (plane == AOM_PLANE_U && eob > 0)
       update_cctx_type_count(cm, xd, blk_row, blk_col, tx_size, td->counts,
                              allow_update_cdf);
-#else
-    if (plane == AOM_PLANE_V &&
-        (eob > 0 || x->plane[AOM_PLANE_U].eobs[block] > 0)) {
-      update_cctx_type_count(cm, xd, blk_row, blk_col, tx_size, td->counts,
-                             allow_update_cdf);
-    }
-#endif  // CCTX_C1_NONZERO
 #endif  // CONFIG_CROSS_CHROMA_TX
     if (eob == 0) {
       av1_set_entropy_contexts(xd, pd, plane, plane_bsize, tx_size, 0, blk_col,
diff --git a/av1/encoder/encodetxb.h b/av1/encoder/encodetxb.h
index 7e1be18..3f2bdcc 100644
--- a/av1/encoder/encodetxb.h
+++ b/av1/encoder/encodetxb.h
@@ -98,7 +98,11 @@
  * \param[in]    tx_type              The transform type.*/
 #if CONFIG_CROSS_CHROMA_TX
 /* \param[in]    cctx_type            The cross chroma component transform
- * type*/
+ * type
+ * \param[in]    blk_row      The row index of the current transform block
+ * in the macroblock. Each unit has 4 pixels in y plane.
+ * \param[in]    blk_col      The col index of the current transform block
+ * in the macroblock. Each unit has 4 pixels in y plane.*/
 #endif  // CONFIG_CROSS_CHROMA_TX
 /* \param[in]    txb_ctx              Context info for entropy coding transform
  block
@@ -113,7 +117,7 @@
     const MACROBLOCK *x, const int plane, const int block,
     const TX_SIZE tx_size, const TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-    const CctxType cctx_type,
+    const CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
     const TXB_CTX *const txb_ctx, int reduced_tx_set_used);
 
@@ -147,7 +151,11 @@
  * \param[in]    tx_size        The transform size
  * \param[in]    tx_type        The transform type*/
 #if CONFIG_CROSS_CHROMA_TX
-/* \param[in]    cctx_type      The cross chroma component transform type*/
+/* \param[in]    cctx_type      The cross chroma component transform type
+ * \param[in]    blk_row      The row index of the current transform block
+ * in the macroblock. Each unit has 4 pixels in y plane.
+ * \param[in]    blk_col      The col index of the current transform block
+ * in the macroblock. Each unit has 4 pixels in y plane.*/
 #endif  // CONFIG_CROSS_CHROMA_TX
 /* \param[in]    txb_ctx        Context info for entropy coding transform block
  * skip flag (tx_skip) and the sign of DC coefficient (dc_sign).
@@ -166,7 +174,7 @@
     const MACROBLOCK *x, const int plane, const int block,
     const TX_SIZE tx_size, const TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-    const CctxType cctx_type,
+    const CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
     const TXB_CTX *const txb_ctx, const int reduced_tx_set_used,
     const int adjust_eob);
@@ -506,6 +514,10 @@
  * \param[in]    tx_size        The transform size
  * \param[in]    tx_type        The transform type
  * \param[in]    cctx_type      The cross chroma component transform type
+ * \param[in]    blk_row      The row index of the current transform block
+ * in the macroblock. Each unit has 4 pixels in y plane.
+ * \param[in]    blk_col      The col index of the current transform block
+ * in the macroblock. Each unit has 4 pixels in y plane.
  * \param[in]    txb_ctx        Context info for entropy coding transform block
  * skip flag (tx_skip) and the sign of DC coefficient (dc_sign).
  * \param[out]   rate_cost      The entropy cost of coding the transform block
@@ -551,7 +563,7 @@
 int av1_optimize_txb_new(const struct AV1_COMP *cpi, MACROBLOCK *x, int plane,
                          int block, TX_SIZE tx_size, TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-                         CctxType cctx_type,
+                         CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
                          const TXB_CTX *const txb_ctx, int *rate_cost,
                          int sharpness);
@@ -575,9 +587,32 @@
  */
 CB_COEFF_BUFFER *av1_get_cb_coeff_buffer(const struct AV1_COMP *cpi, int mi_row,
                                          int mi_col);
+
 #if CONFIG_CROSS_CHROMA_TX
-int get_cctx_type_cost(const MACROBLOCK *x, const MACROBLOCKD *xd, int plane,
-                       TX_SIZE tx_size, int block, CctxType cctx_type);
+/*!\brief Return the entropy cost of coding the cross chroma transform type
+ *
+ * \ingroup coefficient_coding
+ *
+ * \param[in]    cm             Top-level structure shared by encoder and
+ *                              decoder
+ * \param[in]    x              Pointer to structure holding the data for the
+ *                              current encoding macroblock
+ * \param[in]    xd             Pointer to structure holding the data for the
+ *                              current macroblockd
+ * \param[in]    plane          The index of the current plane
+ * \param[in]    tx_size        The transform size
+ * \param[in]    blk_row        The row index of the current transform block
+ * in the macroblock. Each unit spans 4 pixels in the Y plane.
+ * \param[in]    blk_col        The column index of the current transform block
+ * in the macroblock. Each unit spans 4 pixels in the Y plane.
+ * \param[in]    block          The index of the current transform block
+ * \param[in]    cctx_type      The cross chroma transform type
+ *
+ * \return       int            Entropy cost for cctx type
+ */
+int get_cctx_type_cost(const AV1_COMMON *cm, const MACROBLOCK *x,
+                       const MACROBLOCKD *xd, int plane, TX_SIZE tx_size,
+                       int blk_row, int blk_col, int block, CctxType cctx_type);
 #endif  // CONFIG_CROSS_CHROMA_TX
 
 #if CONFIG_CONTEXT_DERIVATION
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 9eaec85..fd15af0 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -322,8 +322,10 @@
 
 #if CONFIG_CROSS_CHROMA_TX
   for (i = 0; i < EXT_TX_SIZES; ++i) {
-    av1_cost_tokens_from_cdf(mode_costs->cctx_type_cost[i],
-                             fc->cctx_type_cdf[i], NULL);
+    for (j = 0; j < CCTX_CONTEXTS; ++j) {
+      av1_cost_tokens_from_cdf(mode_costs->cctx_type_cost[i][j],
+                               fc->cctx_type_cdf[i][j], NULL);
+    }
   }
 #endif  // CONFIG_CROSS_CHROMA_TX
 
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 7baa75e..471743c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1626,7 +1626,7 @@
   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
   TX_TYPE best_tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
 #if CONFIG_CROSS_CHROMA_TX
-  TX_TYPE best_cctx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+  CctxType best_cctx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
 #endif  // CONFIG_CROSS_CHROMA_TX
   const int rate_mv0 = *rate_mv;
   const int interintra_allowed =
@@ -3569,7 +3569,7 @@
   TX_TYPE best_tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
   av1_copy_array(best_tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
 #if CONFIG_CROSS_CHROMA_TX
-  TX_TYPE best_cctx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+  CctxType best_cctx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
   av1_copy_array(best_cctx_type_map, xd->cctx_type_map, ctx->num_4x4_blk);
 #endif  // CONFIG_CROSS_CHROMA_TX
 
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index 4e33bbb..a09bd53 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -1152,7 +1152,7 @@
       if (quant_param_intra.use_optimize_b) {
         av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-                       cctx_type,
+                       cctx_type, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
                        txb_ctx, rate_cost);
       }
@@ -1315,9 +1315,7 @@
   memcpy(tmp_dqcoeff_v, p_v->dqcoeff + BLOCK_OFFSET(block),
          sizeof(tran_low_t) * eob_max);
 
-#if CCTX_C1_NONZERO
   assert(p_u->eobs[block] > 0);
-#endif
   assert(cpi != NULL);
   assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
 
@@ -1551,7 +1549,7 @@
 #endif  // CONFIG_FORWARDSKIP
         x, plane, block, tx_size, tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-        CCTX_NONE,
+        CCTX_NONE, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
         txb_ctx, reduced_tx_set_used, 0);
 
@@ -1592,7 +1590,7 @@
 #endif  // CONFIG_FORWARDSKIP
         x, plane, block, tx_size, tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-        CCTX_NONE,
+        CCTX_NONE, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
         txb_ctx, reduced_tx_set_used, 0);
 
@@ -1712,7 +1710,7 @@
 #endif  // CONFIG_FORWARDSKIP
         x, plane, block, tx_size, tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-        CCTX_NONE,
+        CCTX_NONE, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
         txb_ctx, reduced_tx_set_used, 0);
     // tx domain dist
@@ -2339,7 +2337,7 @@
 #endif  // CONFIG_FORWARDSKIP
     MACROBLOCK *x, int plane, int block, TX_SIZE tx_size, const TX_TYPE tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-    const CctxType cctx_type,
+    const CctxType cctx_type, int blk_row, int blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
     const TXB_CTX *const txb_ctx, int reduced_tx_set_used) {
 #if TXCOEFF_COST_TIMER
@@ -2352,7 +2350,7 @@
 #endif  // CONFIG_FORWARDSKIP
       x, plane, block, tx_size, tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-      cctx_type,
+      cctx_type, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
       txb_ctx, reduced_tx_set_used);
 #if TXCOEFF_COST_TIMER
@@ -2800,7 +2798,7 @@
       if (quant_param.use_optimize_b) {
         av1_optimize_b(cpi, x, plane, block, tx_size, tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-                       CCTX_NONE,
+                       CCTX_NONE, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
                        txb_ctx, &rate_cost);
       } else {
@@ -2810,7 +2808,7 @@
 #endif  // CONFIG_FORWARDSKIP
             x, plane, block, tx_size, tx_type,
 #if CONFIG_CROSS_CHROMA_TX
-            CCTX_NONE,
+            CCTX_NONE, blk_row, blk_col,
 #endif  // CONFIG_CROSS_CHROMA_TX
             txb_ctx, cm->features.reduced_tx_set_used);
       }
@@ -3055,12 +3053,23 @@
   memcpy(orig_coeff_v, p_v->coeff + BLOCK_OFFSET(block),
          sizeof(tran_low_t) * max_eob);
 
+#if CCTX_ADAPT_REDUCED_SET
+  int above_cctx, left_cctx;
+  get_above_and_left_cctx_type(cm, xd, blk_row, blk_col, tx_size, &above_cctx,
+                               &left_cctx);
+  uint8_t cctx_mask = get_allowed_cctx_mask(above_cctx, left_cctx);
+#endif
+
   // Iterate through all transform type candidates.
   for (CctxType cctx_type = CCTX_START; cctx_type < CCTX_TYPES; ++cctx_type) {
+#if CCTX_ADAPT_REDUCED_SET
+    if (!(cctx_mask & (1 << cctx_type))) continue;
+#endif
+
     RD_STATS this_rd_stats;
     av1_invalid_rd_stats(&this_rd_stats);
 
-    update_cctx_array(xd, blk_row, blk_col, tx_size, cctx_type);
+    update_cctx_array(xd, blk_row, blk_col, 0, 0, tx_size, cctx_type);
     forward_cross_chroma_transform(x, block, tx_size, cctx_type);
 
     for (int plane = AOM_PLANE_U; plane <= AOM_PLANE_V; plane++) {
@@ -3092,18 +3101,17 @@
       // Calculate rate cost of quantized coefficients.
       if (quant_param.use_optimize_b) {
         av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, cctx_type,
-                       &txb_ctx_uv[plane - AOM_PLANE_U],
+                       blk_row, blk_col, &txb_ctx_uv[plane - AOM_PLANE_U],
                        &rate_cost[plane - AOM_PLANE_U]);
       } else {
         rate_cost[plane - AOM_PLANE_U] = cost_coeffs(
 #if CONFIG_FORWARDSKIP
             cm,
 #endif  // CONFIG_FORWARDSKIP
-            x, plane, block, tx_size, tx_type, cctx_type,
+            x, plane, block, tx_size, tx_type, cctx_type, blk_row, blk_col,
             &txb_ctx_uv[plane - AOM_PLANE_U], cm->features.reduced_tx_set_used);
       }
     }
-#if CCTX_C1_NONZERO
     // TODO(kslu) for negative angles, skip av1_xform_quant and reuse previous
     // dqcoeffs
     uint64_t sse_dqcoeff_u =
@@ -3121,7 +3129,6 @@
       }
       continue;
     }
-#endif
 
     // If rd cost based on coeff rate alone is already more than best_rd,
     // terminate early.
@@ -3171,15 +3178,13 @@
   assert(best_rd != INT64_MAX);
 
   best_rd_stats->skip_txfm = (best_eob_u == 0 && best_eob_v == 0);
-  update_cctx_array(xd, blk_row, blk_col, tx_size, best_cctx_type);
+  update_cctx_array(xd, blk_row, blk_col, 0, 0, tx_size, best_cctx_type);
   p_u->txb_entropy_ctx[block] = best_txb_ctx_u;
   p_v->txb_entropy_ctx[block] = best_txb_ctx_v;
   p_u->eobs[block] = best_eob_u;
   p_v->eobs[block] = best_eob_v;
 
-#if CCTX_C1_NONZERO
   assert(IMPLIES(best_cctx_type > CCTX_NONE, best_eob_u > 0));
-#endif
 #if CCTX_C2_DROPPED
   assert(IMPLIES(!keep_chroma_c2(best_cctx_type), best_eob_v == 0));
 #endif
@@ -4284,7 +4289,7 @@
   const AV1_COMMON *cm = &cpi->common;
   RD_STATS rd_stats_joint_uv;
   av1_init_rd_stats(&rd_stats_joint_uv);
-  update_cctx_array(xd, blk_row, blk_col, tx_size, CCTX_NONE);
+  update_cctx_array(xd, blk_row, blk_col, 0, 0, tx_size, CCTX_NONE);
 
   // Obtain RD cost for CCTX_NONE
   RD_STATS rd_stats_uv[2];
diff --git a/tools/aom_entropy_optimizer.c b/tools/aom_entropy_optimizer.c
index a72a843..c98d648 100644
--- a/tools/aom_entropy_optimizer.c
+++ b/tools/aom_entropy_optimizer.c
@@ -375,10 +375,11 @@
 #if CONFIG_CROSS_CHROMA_TX
   /* cctx type */
   cts_each_dim[0] = EXT_TX_SIZES;
-  cts_each_dim[1] = CCTX_TYPES;
-  optimize_cdf_table(&fc.cctx_type[0][0], probsfile, 2, cts_each_dim,
+  cts_each_dim[1] = CCTX_CONTEXTS;
+  cts_each_dim[2] = CCTX_TYPES_ALLOWED;
+  optimize_cdf_table(&fc.cctx_type[0][0][0], probsfile, 3, cts_each_dim,
                      "static const aom_cdf_prob default_cctx_type[EXT_TX_SIZES]"
-                     "[CDF_SIZE(CCTX_TYPES)]");
+                     "[CCTX_CONTEXTS][CDF_SIZE(CCTX_TYPES_ALLOWED)]");
 #endif  // CONFIG_CROSS_CHROMA_TX
 
   /* tx type */