Introduce CONFIG_CROSS_CHROMA_TX experiment

STATS_CHANGED

Change-Id: Id649b41d3d3ddccd6e594d31999c695622d576e2
diff --git a/av1/common/av1_txfm.c b/av1/common/av1_txfm.c
index 0fe1e7f..a38b11d 100644
--- a/av1/common/av1_txfm.c
+++ b/av1/common/av1_txfm.c
@@ -196,6 +196,11 @@
     };
 #endif  // CONFIG_DST_32X32
 
+#if CONFIG_CROSS_CHROMA_TX
+// Haar transform [1, 1; 1, -1] / sqrt(2), scaled by (1 << CCTX_PREC_BITS).
+const int32_t cctx_mtx[4] = { 181, 181, 181, -181 };
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 // av1_sinpi_arr_data[i][j] = (int)round((sqrt(2) * sin(j*Pi/9) * 2 / 3) * (1
 // << (cos_bit_min + i))) modified so that elements j=1,2 sum to element j=4.
 const int32_t av1_sinpi_arr_data[7][5] = {
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h
index 5958434..f68cf5e 100644
--- a/av1/common/av1_txfm.h
+++ b/av1/common/av1_txfm.h
@@ -49,6 +49,12 @@
 #define KLT_PREC_BITS 10
 #endif  // CONFIG_DDT_INTER
 
+#if CONFIG_CROSS_CHROMA_TX
+#define CCTX_DC_ONLY 0    // 1: transform only the first coefficient per block
+#define CCTX_PREC_BITS 8  // Fixed-point precision of the cctx_mtx entries
+extern const int32_t cctx_mtx[4];  // 2x2 matrix, defined in av1_txfm.c
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 #define MAX_TXFM_STAGE_NUM 12
 
 static const int cos_bit_min = 10;
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 7f97bf6..7e56a68 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -320,6 +320,29 @@
   }
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+// Inverse cross-chroma transform, in place: maps each dequantized (U, V)
+// coefficient pair through the transpose of cctx_mtx, rounded by CCTX_PREC_BITS.
+void av1_inv_cross_chroma_tx_block(tran_low_t *dqcoeff_u, tran_low_t *dqcoeff_v,
+                                   TX_SIZE tx_size) {
+#if CCTX_DC_ONLY
+  const int ncoeffs = 1;
+#else
+  const int ncoeffs = av1_get_max_eob(tx_size);
+#endif
+  // TODO(kslu): keep track of the EOB before fwd and after inv cctx
+  for (int i = 0; i < ncoeffs; i++) {
+    // Accumulate in 64 bits so the two products cannot overflow int32_t.
+    const int64_t u = dqcoeff_u[i];
+    const int64_t v = dqcoeff_v[i];
+    const int64_t sum_u = cctx_mtx[0] * u + cctx_mtx[2] * v;
+    const int64_t sum_v = cctx_mtx[1] * u + cctx_mtx[3] * v;
+    dqcoeff_u[i] = (tran_low_t)ROUND_POWER_OF_TWO_SIGNED(sum_u, CCTX_PREC_BITS);
+    dqcoeff_v[i] = (tran_low_t)ROUND_POWER_OF_TWO_SIGNED(sum_v, CCTX_PREC_BITS);
+  }
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 void av1_inverse_transform_block(const MACROBLOCKD *xd,
 #if CONFIG_IST
                                  tran_low_t *dqcoeff,
diff --git a/av1/common/idct.h b/av1/common/idct.h
index 216f5ca..62b5c62 100644
--- a/av1/common/idct.h
+++ b/av1/common/idct.h
@@ -33,6 +33,11 @@
 #define MAX_TX_SCALE 1
 int av1_get_tx_scale(const TX_SIZE tx_size);
 
+#if CONFIG_CROSS_CHROMA_TX
+void av1_inv_cross_chroma_tx_block(tran_low_t *dqcoeff_u, tran_low_t *dqcoeff_v,
+                                   TX_SIZE tx_size);
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 void av1_inverse_transform_block(const MACROBLOCKD *xd,
 #if CONFIG_IST
                                  tran_low_t *dqcoeff,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 5ddf573..1c4664d 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -158,8 +158,24 @@
   tran_low_t *const dqcoeff = dcb->dqcoeff_block[plane] + dcb->cb_offset[plane];
 #endif
   eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
+#if CONFIG_CROSS_CHROMA_TX
+  // TODO(kslu): keep track of transform domain eobs of U and V
+  eob_info *eob_data_u =
+      dcb->eob_data[AOM_PLANE_U] + dcb->txb_offset[AOM_PLANE_U];
+  eob_info *eob_data_v =
+      dcb->eob_data[AOM_PLANE_V] + dcb->txb_offset[AOM_PLANE_V];
+  uint16_t scan_line = (plane == 0) ? eob_data->max_scan_line
+                                    : AOMMIN(av1_get_max_eob(tx_size),
+                                             AOMMAX(eob_data_u->max_scan_line,
+                                                    eob_data_v->max_scan_line));
+  uint16_t eob = (plane == 0)
+                     ? eob_data->eob
+                     : AOMMIN(av1_get_max_eob(tx_size),
+                              AOMMAX(eob_data_u->eob, eob_data_v->eob));
+#else
   uint16_t scan_line = eob_data->max_scan_line;
   uint16_t eob = eob_data->eob;
+#endif  // CONFIG_CROSS_CHROMA_TX
   av1_inverse_transform_block(&dcb->xd, dqcoeff, plane, tx_type, tx_size, dst,
                               stride, eob, reduced_tx_set);
 #if CONFIG_IST
@@ -249,6 +265,24 @@
   }
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+// Decoder visitor: runs the inverse cross-chroma transform on the current U
+// and V dequantized coefficient buffers. Only dcb and tx_size are used.
+static AOM_INLINE void inverse_cross_chroma_transform_block(
+    const AV1_COMMON *const cm, DecoderCodingBlock *dcb, aom_reader *const r,
+    const int plane, const int blk_row, const int blk_col,
+    const TX_SIZE tx_size) {
+  (void)cm;
+  (void)r;
+  (void)plane;
+  (void)blk_row;
+  (void)blk_col;
+  av1_inv_cross_chroma_tx_block(
+      &dcb->dqcoeff_block[AOM_PLANE_U][dcb->cb_offset[AOM_PLANE_U]],
+      &dcb->dqcoeff_block[AOM_PLANE_V][dcb->cb_offset[AOM_PLANE_V]], tx_size);
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 static AOM_INLINE void inverse_transform_inter_block(
     const AV1_COMMON *const cm, DecoderCodingBlock *dcb, aom_reader *const r,
     const int plane, const int blk_row, const int blk_col,
@@ -291,6 +325,9 @@
     AV1_COMMON *cm, ThreadData *const td, aom_reader *r,
     MB_MODE_INFO *const mbmi, int plane, BLOCK_SIZE plane_bsize, int blk_row,
     int blk_col, int block, TX_SIZE tx_size, int *eob_total) {
+#if CONFIG_CROSS_CHROMA_TX
+  if (plane == AOM_PLANE_U) return;
+#endif  // CONFIG_CROSS_CHROMA_TX
   DecoderCodingBlock *const dcb = &td->dcb;
   MACROBLOCKD *const xd = &dcb->xd;
   const struct macroblockd_plane *const pd = &xd->plane[plane];
@@ -308,6 +345,39 @@
   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
 
   if (tx_size == plane_tx_size || plane) {
+#if CONFIG_CROSS_CHROMA_TX
+    switch (plane) {
+      case AOM_PLANE_Y:
+        td->read_coeffs_tx_inter_block_visit(cm, dcb, r, plane, blk_row,
+                                             blk_col, tx_size);
+
+        td->inverse_tx_inter_block_visit(cm, dcb, r, plane, blk_row, blk_col,
+                                         tx_size);
+        eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
+        *eob_total += eob_data->eob;
+        set_cb_buffer_offsets(dcb, tx_size, plane);
+        break;
+      case AOM_PLANE_V:
+        td->read_coeffs_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_U, blk_row,
+                                             blk_col, tx_size);
+        td->read_coeffs_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_V, blk_row,
+                                             blk_col, tx_size);
+        td->inverse_cctx_block_visit(cm, dcb, r, -1, blk_row, blk_col, tx_size);
+        td->inverse_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_U, blk_row,
+                                         blk_col, tx_size);
+        td->inverse_tx_inter_block_visit(cm, dcb, r, AOM_PLANE_V, blk_row,
+                                         blk_col, tx_size);
+        eob_info *eob_data_u =
+            dcb->eob_data[AOM_PLANE_U] + dcb->txb_offset[AOM_PLANE_U];
+        eob_info *eob_data_v =
+            dcb->eob_data[AOM_PLANE_V] + dcb->txb_offset[AOM_PLANE_V];
+        *eob_total += eob_data_u->eob + eob_data_v->eob;
+        set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_U);
+        set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_V);
+        break;
+      case AOM_PLANE_U: assert(0); break;
+    }
+#else
     td->read_coeffs_tx_inter_block_visit(cm, dcb, r, plane, blk_row, blk_col,
                                          tx_size);
 
@@ -316,6 +386,7 @@
     eob_info *eob_data = dcb->eob_data[plane] + dcb->txb_offset[plane];
     *eob_total += eob_data->eob;
     set_cb_buffer_offsets(dcb, tx_size, plane);
+#endif  // CONFIG_CROSS_CHROMA_TX
   } else {
 #if CONFIG_NEW_TX_PARTITION
     TX_SIZE sub_txs[MAX_TX_PARTITIONS] = { 0 };
@@ -1161,104 +1232,94 @@
   const int plane_start = get_partition_plane_start(xd->tree_type);
   const int plane_end =
       get_partition_plane_end(xd->tree_type, av1_num_planes(cm));
-  if (!is_inter_block(mbmi, xd->tree_type)) {
-    int row, col;
+
+  int row, col;
+  const int max_blocks_wide = max_block_wide(xd, bsize, 0);
+  const int max_blocks_high = max_block_high(xd, bsize, 0);
+  const BLOCK_SIZE max_unit_bsize = BLOCK_64X64;
+  int mu_blocks_wide = mi_size_wide[max_unit_bsize];
+  int mu_blocks_high = mi_size_high[max_unit_bsize];
+  mu_blocks_wide = AOMMIN(max_blocks_wide, mu_blocks_wide);
+  mu_blocks_high = AOMMIN(max_blocks_high, mu_blocks_high);
+
+  const int is_inter = is_inter_block(mbmi, xd->tree_type);
+  if (!is_inter) {
     assert(bsize == get_plane_block_size(bsize, xd->plane[0].subsampling_x,
                                          xd->plane[0].subsampling_y));
-    const int max_blocks_wide = max_block_wide(xd, bsize, 0);
-    const int max_blocks_high = max_block_high(xd, bsize, 0);
-    const BLOCK_SIZE max_unit_bsize = BLOCK_64X64;
-    int mu_blocks_wide = mi_size_wide[max_unit_bsize];
-    int mu_blocks_high = mi_size_high[max_unit_bsize];
-    mu_blocks_wide = AOMMIN(max_blocks_wide, mu_blocks_wide);
-    mu_blocks_high = AOMMIN(max_blocks_high, mu_blocks_high);
-
-    for (row = 0; row < max_blocks_high; row += mu_blocks_high) {
-      for (col = 0; col < max_blocks_wide; col += mu_blocks_wide) {
-        for (int plane = plane_start; plane < plane_end; ++plane) {
-          if (plane && !xd->is_chroma_ref) break;
-          const struct macroblockd_plane *const pd = &xd->plane[plane];
-          const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
-          const int stepr = tx_size_high_unit[tx_size];
-          const int stepc = tx_size_wide_unit[tx_size];
-
-          const int unit_height = ROUND_POWER_OF_TWO(
-              AOMMIN(mu_blocks_high + row, max_blocks_high), pd->subsampling_y);
-          const int unit_width = ROUND_POWER_OF_TWO(
-              AOMMIN(mu_blocks_wide + col, max_blocks_wide), pd->subsampling_x);
-
-          for (int blk_row = row >> pd->subsampling_y; blk_row < unit_height;
-               blk_row += stepr) {
-            for (int blk_col = col >> pd->subsampling_x; blk_col < unit_width;
-                 blk_col += stepc) {
-              td->read_coeffs_tx_intra_block_visit(cm, dcb, r, plane, blk_row,
-                                                   blk_col, tx_size);
-              td->predict_and_recon_intra_block_visit(
-                  cm, dcb, r, plane, blk_row, blk_col, tx_size);
-              set_cb_buffer_offsets(dcb, tx_size, plane);
-            }
-          }
-        }
-      }
-    }
   } else {
     td->predict_inter_block_visit(cm, dcb, bsize);
-    // Reconstruction
-    if (!mbmi->skip_txfm[xd->tree_type == CHROMA_PART]) {
-      int eobtotal = 0;
-
-      const int max_blocks_wide = max_block_wide(xd, bsize, 0);
-      const int max_blocks_high = max_block_high(xd, bsize, 0);
-      int row, col;
-
-      const BLOCK_SIZE max_unit_bsize = BLOCK_64X64;
+    if (!mbmi->skip_txfm[xd->tree_type == CHROMA_PART])
       assert(max_unit_bsize ==
              get_plane_block_size(BLOCK_64X64, xd->plane[0].subsampling_x,
                                   xd->plane[0].subsampling_y));
-      int mu_blocks_wide = mi_size_wide[max_unit_bsize];
-      int mu_blocks_high = mi_size_high[max_unit_bsize];
+  }
 
-      mu_blocks_wide = AOMMIN(max_blocks_wide, mu_blocks_wide);
-      mu_blocks_high = AOMMIN(max_blocks_high, mu_blocks_high);
+  for (row = 0; row < max_blocks_high; row += mu_blocks_high) {
+    for (col = 0; col < max_blocks_wide; col += mu_blocks_wide) {
+      for (int plane = plane_start; plane < plane_end; ++plane) {
+        if (plane && !xd->is_chroma_ref) break;
+        const struct macroblockd_plane *const pd = &xd->plane[plane];
+        const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
+        const int ss_x = pd->subsampling_x;
+        const int ss_y = pd->subsampling_y;
+        const int unit_height = ROUND_POWER_OF_TWO(
+            AOMMIN(mu_blocks_high + row, max_blocks_high), ss_y);
+        const int unit_width = ROUND_POWER_OF_TWO(
+            AOMMIN(mu_blocks_wide + col, max_blocks_wide), ss_x);
 
-      for (row = 0; row < max_blocks_high; row += mu_blocks_high) {
-        for (col = 0; col < max_blocks_wide; col += mu_blocks_wide) {
-          for (int plane = plane_start; plane < plane_end; ++plane) {
-            if (plane && !xd->is_chroma_ref) break;
-            const struct macroblockd_plane *const pd = &xd->plane[plane];
-            const int ss_x = pd->subsampling_x;
-            const int ss_y = pd->subsampling_y;
-            const BLOCK_SIZE plane_bsize =
-                get_plane_block_size(bsize, ss_x, ss_y);
-            const TX_SIZE max_tx_size =
-                get_vartx_max_txsize(xd, plane_bsize, plane);
-            const int bh_var_tx = tx_size_high_unit[max_tx_size];
-            const int bw_var_tx = tx_size_wide_unit[max_tx_size];
-            int block = 0;
-            int step =
-                tx_size_wide_unit[max_tx_size] * tx_size_high_unit[max_tx_size];
-            int blk_row, blk_col;
-            const int unit_height = ROUND_POWER_OF_TWO(
-                AOMMIN(mu_blocks_high + row, max_blocks_high), ss_y);
-            const int unit_width = ROUND_POWER_OF_TWO(
-                AOMMIN(mu_blocks_wide + col, max_blocks_wide), ss_x);
-
-            for (blk_row = row >> ss_y; blk_row < unit_height;
-                 blk_row += bh_var_tx) {
-              for (blk_col = col >> ss_x; blk_col < unit_width;
-                   blk_col += bw_var_tx) {
+        const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, ss_x, ss_y);
+        const TX_SIZE max_tx_size =
+            get_vartx_max_txsize(xd, plane_bsize, plane);
+        const int stepr = is_inter ? tx_size_high_unit[max_tx_size]
+                                   : tx_size_high_unit[tx_size];
+        const int stepc = is_inter ? tx_size_wide_unit[max_tx_size]
+                                   : tx_size_wide_unit[tx_size];
+        int eobtotal = 0;
+        int block = 0;
+        for (int blk_row = row >> ss_y; blk_row < unit_height;
+             blk_row += stepr) {
+          for (int blk_col = col >> ss_x; blk_col < unit_width;
+               blk_col += stepc) {
+            if (!is_inter) {
+              td->read_coeffs_tx_intra_block_visit(cm, dcb, r, plane, blk_row,
+                                                   blk_col, tx_size);
+#if CONFIG_CROSS_CHROMA_TX
+              switch (plane) {
+                case AOM_PLANE_Y:
+                  td->predict_and_recon_intra_block_visit(
+                      cm, dcb, r, AOM_PLANE_Y, blk_row, blk_col, tx_size);
+                  set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_Y);
+                  break;
+                case AOM_PLANE_U: break;
+                case AOM_PLANE_V:
+                  td->predict_and_recon_intra_block_visit(
+                      cm, dcb, r, AOM_PLANE_U, blk_row, blk_col, tx_size);
+                  set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_U);
+                  td->predict_and_recon_intra_block_visit(
+                      cm, dcb, r, AOM_PLANE_V, blk_row, blk_col, tx_size);
+                  set_cb_buffer_offsets(dcb, tx_size, AOM_PLANE_V);
+                  break;
+              }
+#else
+              td->predict_and_recon_intra_block_visit(
+                  cm, dcb, r, plane, blk_row, blk_col, tx_size);
+              set_cb_buffer_offsets(dcb, tx_size, plane);
+#endif  // CONFIG_CROSS_CHROMA_TX
+            } else {
+              // Reconstruction
+              if (!mbmi->skip_txfm[xd->tree_type == CHROMA_PART]) {
                 decode_reconstruct_tx(cm, td, r, mbmi, plane, plane_bsize,
                                       blk_row, blk_col, block, max_tx_size,
                                       &eobtotal);
-                block += step;
+                block += stepr * stepc;
               }
             }
           }
         }
       }
     }
-    td->cfl_store_inter_block_visit(cm, xd);
   }
+  if (is_inter) td->cfl_store_inter_block_visit(cm, xd);
 
   av1_visit_palette(pbi, xd, r, set_color_index_map_offset);
 }
@@ -3417,6 +3478,9 @@
   td->predict_and_recon_intra_block_visit = decode_block_void;
   td->read_coeffs_tx_inter_block_visit = decode_block_void;
   td->inverse_tx_inter_block_visit = decode_block_void;
+#if CONFIG_CROSS_CHROMA_TX
+  td->inverse_cctx_block_visit = decode_block_void;
+#endif  // CONFIG_CROSS_CHROMA_TX
   td->predict_inter_block_visit = predict_inter_block_void;
   td->cfl_store_inter_block_visit = cfl_store_inter_block_void;
 
@@ -3428,6 +3492,9 @@
     td->predict_and_recon_intra_block_visit =
         predict_and_reconstruct_intra_block;
     td->inverse_tx_inter_block_visit = inverse_transform_inter_block;
+#if CONFIG_CROSS_CHROMA_TX
+    td->inverse_cctx_block_visit = inverse_cross_chroma_transform_block;
+#endif  // CONFIG_CROSS_CHROMA_TX
     td->predict_inter_block_visit = predict_inter_block;
     td->cfl_store_inter_block_visit = cfl_store_inter_block;
   }
diff --git a/av1/decoder/decoder.h b/av1/decoder/decoder.h
index 78b6fc9..5f151cd 100644
--- a/av1/decoder/decoder.h
+++ b/av1/decoder/decoder.h
@@ -125,6 +125,9 @@
   decode_block_visitor_fn_t read_coeffs_tx_intra_block_visit;
   decode_block_visitor_fn_t predict_and_recon_intra_block_visit;
   decode_block_visitor_fn_t read_coeffs_tx_inter_block_visit;
+#if CONFIG_CROSS_CHROMA_TX
+  decode_block_visitor_fn_t inverse_cctx_block_visit;
+#endif  // CONFIG_CROSS_CHROMA_TX
   decode_block_visitor_fn_t inverse_tx_inter_block_visit;
   predict_inter_block_visitor_fn_t predict_inter_block_visit;
   cfl_store_inter_block_visitor_fn_t cfl_store_inter_block_visit;
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 003be64..3cc29b1 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -273,11 +273,45 @@
   const struct macroblock_plane *const p = &x->plane[plane];
   const int is_inter = is_inter_block(mbmi, xd->tree_type);
 #endif  // CONFIG_FORWARDSKIP
+#if CONFIG_CROSS_CHROMA_TX
+  if (is_inter_block(x->e_mbd.mi[0], x->e_mbd.tree_type)) {
+    switch (plane) {
+      case AOM_PLANE_Y:
 #if CONFIG_IST
-  av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, txfm_param, 0);
+        av1_xform(x, AOM_PLANE_Y, block, blk_row, blk_col, plane_bsize,
+                  txfm_param, 0);
+#else
+        av1_xform(x, AOM_PLANE_Y, block, blk_row, blk_col, plane_bsize,
+                  txfm_param);
+#endif
+        break;
+      case AOM_PLANE_U:
+#if CONFIG_IST
+        av1_xform(x, AOM_PLANE_U, block, blk_row, blk_col, plane_bsize,
+                  txfm_param, 0);
+        av1_xform(x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
+                  txfm_param, 0);
+#else
+        av1_xform(x, AOM_PLANE_U, block, blk_row, blk_col, plane_bsize,
+                  txfm_param);
+        av1_xform(x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
+                  txfm_param);
+#endif
+        forward_cross_chroma_transform(x, block, txfm_param->tx_size);
+        // TODO(kslu): maybe skip av1_setup_xform for V
+        break;
+      case AOM_PLANE_V: break;
+    }
+  } else {
+#endif  // CONFIG_CROSS_CHROMA_TX
+#if CONFIG_IST
+    av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, txfm_param, 0);
 #else
   av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, txfm_param);
 #endif
+#if CONFIG_CROSS_CHROMA_TX
+  }
+#endif  // CONFIG_CROSS_CHROMA_TX
 #if CONFIG_FORWARDSKIP
   const uint8_t fsc_mode =
       (mbmi->fsc_mode[xd->tree_type == CHROMA_PART] && plane == PLANE_TYPE_Y) ||
@@ -355,6 +389,17 @@
 #endif
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+// Applies the forward cross-chroma transform, in place, to the U and V
+// transform coefficients of the given block.
+void forward_cross_chroma_transform(MACROBLOCK *x, int block, TX_SIZE tx_size) {
+  const int offset = BLOCK_OFFSET(block);
+  tran_low_t *const u = x->plane[AOM_PLANE_U].coeff + offset;
+  tran_low_t *const v = x->plane[AOM_PLANE_V].coeff + offset;
+  av1_fwd_cross_chroma_tx_block(u, v, tx_size);
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 void av1_quant(MACROBLOCK *x, int plane, int block, TxfmParam *txfm_param,
                QUANT_PARAM *qparam) {
   const struct macroblock_plane *const p = &x->plane[plane];
@@ -483,7 +528,7 @@
   MB_MODE_INFO *mbmi = xd->mi[0];
   struct macroblock_plane *const p = &x->plane[plane];
   struct macroblockd_plane *const pd = &xd->plane[plane];
-#if CONFIG_IST
+#if CONFIG_IST || CONFIG_CROSS_CHROMA_TX
   tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
 #else
   tran_low_t *const dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
@@ -575,13 +620,45 @@
 
   av1_set_txb_context(x, plane, block, tx_size, a, l);
 
-  if (p->eobs[block]) {
-    *(args->skip) = 0;
-    av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
-                                pd->dst.stride, p->eobs[block],
+#if CONFIG_CROSS_CHROMA_TX
+  // In CONFIG_CROSS_CHROMA_TX, reconstruction for U plane relies on dqcoeffs of
+  // V plane, so the below operations for U are performed together with V once
+  // dqcoeffs of V are obtained.
+  if (is_inter_block(mbmi, xd->tree_type) && plane == AOM_PLANE_U) {
+    if (p->eobs[block]) *(args->skip) = 0;
+    return;
+  } else if (is_inter_block(mbmi, xd->tree_type) && plane == AOM_PLANE_V) {
+    struct macroblock_plane *const p_u = &x->plane[AOM_PLANE_U];
+    tran_low_t *dqcoeff_u = x->plane[AOM_PLANE_U].dqcoeff + BLOCK_OFFSET(block);
+    struct macroblockd_plane *const pd_u = &xd->plane[AOM_PLANE_U];
+    uint8_t *dst_u =
+        &pd_u->dst.buf[(blk_row * pd_u->dst.stride + blk_col) << MI_SIZE_LOG2];
+    av1_inv_cross_chroma_tx_block(dqcoeff_u, dqcoeff, tx_size);
+    av1_inverse_transform_block(xd, dqcoeff_u, AOM_PLANE_U, tx_type, tx_size,
+                                dst_u, pd_u->dst.stride,
+                                AOMMAX(p_u->eobs[block], p->eobs[block]),
                                 cm->features.reduced_tx_set_used);
   }
 
+  // TODO(kslu): keep track of transform domain eobs for U and V
+  if (p->eobs[block] || (plane && (x->plane[AOM_PLANE_U].eobs[block] ||
+                                   x->plane[AOM_PLANE_V].eobs[block]))) {
+#else
+  if (p->eobs[block]) {
+#endif  // CONFIG_CROSS_CHROMA_TX
+    *(args->skip) = 0;
+    av1_inverse_transform_block(
+        xd, dqcoeff, plane, tx_type, tx_size, dst, pd->dst.stride,
+#if CONFIG_CROSS_CHROMA_TX
+        (plane == 0) ? p->eobs[block]
+                     : AOMMAX(x->plane[AOM_PLANE_U].eobs[block],
+                              x->plane[AOM_PLANE_V].eobs[block]),
+#else
+        p->eobs[block],
+#endif
+        cm->features.reduced_tx_set_used);
+  }
+
   // TODO(debargha, jingning): Temporarily disable txk_type check for eob=0
   // case. It is possible that certain collision in hash index would cause
   // the assertion failure. To further optimize the rate-distortion
@@ -611,6 +688,18 @@
     int blk_h = block_size_high[bsize];
     mi_to_pixel_loc(&pixel_c, &pixel_r, xd->mi_col, xd->mi_row, blk_col,
                     blk_row, pd->subsampling_x, pd->subsampling_y);
+#if CONFIG_CROSS_CHROMA_TX
+    // With cross-chroma tx the U block is reconstructed with V; record it too.
+    if (plane == AOM_PLANE_V) {
+      struct macroblockd_plane *const pd_u = &xd->plane[AOM_PLANE_U];
+      uint8_t *dst_u =
+          &pd_u->dst
+               .buf[(blk_row * pd_u->dst.stride + blk_col) << MI_SIZE_LOG2];
+      mismatch_record_block_tx(dst_u, pd_u->dst.stride,
+                               cm->current_frame.order_hint, AOM_PLANE_U,
+                               pixel_c, pixel_r, blk_w, blk_h);
+    }
+#endif  // CONFIG_CROSS_CHROMA_TX
     mismatch_record_block_tx(dst, pd->dst.stride, cm->current_frame.order_hint,
                              plane, pixel_c, pixel_r, blk_w, blk_h);
   }
@@ -808,6 +897,17 @@
     cpi,  x,    &ctx,    &mbmi->skip_txfm[xd->tree_type == CHROMA_PART],
     NULL, NULL, dry_run, cpi->optimize_seg_arr[mbmi->segment_id]
   };
+#if CONFIG_CROSS_CHROMA_TX
+  // Subtract first, so both U and V residues will be available when U component
+  // is being transformed and quantized.
+  for (int plane = plane_start; plane < plane_end; ++plane) {
+    const struct macroblockd_plane *const pd = &xd->plane[plane];
+    if (plane && !xd->is_chroma_ref) break;
+    const BLOCK_SIZE plane_bsize =
+        get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
+    av1_subtract_plane(x, plane_bsize, plane);
+  }
+#endif  // CONFIG_CROSS_CHROMA_TX
   for (int plane = plane_start; plane < plane_end; ++plane) {
     const struct macroblockd_plane *const pd = &xd->plane[plane];
     const int subsampling_x = pd->subsampling_x;
@@ -826,7 +926,10 @@
     const int step =
         tx_size_wide_unit[max_tx_size] * tx_size_high_unit[max_tx_size];
     av1_get_entropy_contexts(plane_bsize, pd, ctx.ta[plane], ctx.tl[plane]);
+#if !CONFIG_CROSS_CHROMA_TX
     av1_subtract_plane(x, plane_bsize, plane);
+#endif  // !CONFIG_CROSS_CHROMA_TX
+
     arg.ta = ctx.ta[plane];
     arg.tl = ctx.tl[plane];
     const BLOCK_SIZE max_unit_bsize =
@@ -1005,6 +1108,7 @@
   }
 
   if (*eob) {
+    // TODO(kslu) apply inv cctx for u plane once it is needed for intra
     av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
                                 dst_stride, *eob,
                                 cm->features.reduced_tx_set_used);
diff --git a/av1/encoder/encodemb.h b/av1/encoder/encodemb.h
index c289094..47f0b36 100644
--- a/av1/encoder/encodemb.h
+++ b/av1/encoder/encodemb.h
@@ -109,6 +109,10 @@
 #endif
 );
 
+#if CONFIG_CROSS_CHROMA_TX
+void forward_cross_chroma_transform(MACROBLOCK *x, int block, TX_SIZE tx_size);
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 void av1_quant(MACROBLOCK *x, int plane, int block, TxfmParam *txfm_param,
                QUANT_PARAM *qparam);
 
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index 5ef2ff8..001f672 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -499,6 +499,27 @@
   }
 }
 
+#if CONFIG_CROSS_CHROMA_TX
+// In-place forward cross-chroma transform: (U, V) <- cctx_mtx * (U, V), rounded.
+void av1_fwd_cross_chroma_tx_block(tran_low_t *coeff_u, tran_low_t *coeff_v,
+                                   TX_SIZE tx_size) {
+#if CCTX_DC_ONLY
+  const int ncoeffs = 1;
+#else
+  const int ncoeffs = av1_get_max_eob(tx_size);
+#endif
+  for (int i = 0; i < ncoeffs; i++) {
+    // Accumulate in 64 bits so the two products cannot overflow int32_t.
+    const int64_t u = coeff_u[i];
+    const int64_t v = coeff_v[i];
+    const int64_t sum_u = cctx_mtx[0] * u + cctx_mtx[1] * v;
+    const int64_t sum_v = cctx_mtx[2] * u + cctx_mtx[3] * v;
+    coeff_u[i] = (tran_low_t)ROUND_POWER_OF_TWO_SIGNED(sum_u, CCTX_PREC_BITS);
+    coeff_v[i] = (tran_low_t)ROUND_POWER_OF_TWO_SIGNED(sum_v, CCTX_PREC_BITS);
+  }
+}
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 #if CONFIG_IST
 void av1_fwd_stxfm(tran_low_t *coeff, TxfmParam *txfm_param) {
   const TX_TYPE stx_type = txfm_param->sec_tx_type;
diff --git a/av1/encoder/hybrid_fwd_txfm.h b/av1/encoder/hybrid_fwd_txfm.h
index e171f67..6bc7ffb 100644
--- a/av1/encoder/hybrid_fwd_txfm.h
+++ b/av1/encoder/hybrid_fwd_txfm.h
@@ -25,6 +25,11 @@
 void av1_highbd_fwd_txfm(const int16_t *src_diff, tran_low_t *coeff,
                          int diff_stride, TxfmParam *txfm_param);
 
+#if CONFIG_CROSS_CHROMA_TX
+void av1_fwd_cross_chroma_tx_block(tran_low_t *coeff_u, tran_low_t *coeff_v,
+                                   TX_SIZE tx_size);
+#endif  // CONFIG_CROSS_CHROMA_TX
+
 #if CONFIG_IST
 void av1_fwd_stxfm(tran_low_t *coeff, TxfmParam *txfm_param);
 #endif
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index 3216e4a..4db1c8c 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -1142,6 +1142,7 @@
       }
     }
 
+    // TODO(kslu) apply inv cctx for u plane once it is needed for intra
     inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
                                    x->plane[plane].eobs[block],
                                    cm->features.reduced_tx_set_used);
@@ -1217,7 +1218,7 @@
   const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
   const uint8_t *src = &x->plane[plane].src.buf[src_idx];
   const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
-#if CONFIG_IST
+#if CONFIG_IST || CONFIG_CROSS_CHROMA_TX
   tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
 #else
   const tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
@@ -1236,6 +1237,12 @@
   const PLANE_TYPE plane_type = get_plane_type(plane);
   TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
                                     cpi->common.features.reduced_tx_set_used);
+#if CONFIG_CROSS_CHROMA_TX
+  if (is_inter_block(xd->mi[0], xd->tree_type) && plane == AOM_PLANE_U) {
+    tran_low_t *dqcoeff_v = x->plane[AOM_PLANE_V].dqcoeff + BLOCK_OFFSET(block);
+    av1_inv_cross_chroma_tx_block(dqcoeff, dqcoeff_v, tx_size);
+  }
+#endif  // CONFIG_CROSS_CHROMA_TX
   av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
                               MAX_TX_SIZE, eob,
                               cpi->common.features.reduced_tx_set_used);
@@ -2601,15 +2608,60 @@
       RD_STATS this_rd_stats;
       av1_invalid_rd_stats(&this_rd_stats);
 
-      if (!dc_only_blk)
+#if CONFIG_CROSS_CHROMA_TX
+      if (is_inter_block(mbmi, xd->tree_type)) {
+        switch (plane) {
+          case AOM_PLANE_Y:
+            if (!dc_only_blk) {
 #if CONFIG_IST
-        av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
-                  1);
+              av1_xform(x, AOM_PLANE_Y, block, blk_row, blk_col, plane_bsize,
+                        &txfm_param, 1);
+#else
+              av1_xform(x, AOM_PLANE_Y, block, blk_row, blk_col, plane_bsize,
+                        &txfm_param);
+#endif
+            } else {
+              av1_xform_dc_only(x, AOM_PLANE_Y, block, &txfm_param,
+                                per_px_mean);
+            }
+            break;
+          case AOM_PLANE_U:
+            if (!dc_only_blk) {
+#if CONFIG_IST
+              av1_xform(x, AOM_PLANE_U, block, blk_row, blk_col, plane_bsize,
+                        &txfm_param, 1);
+              av1_xform(x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
+                        &txfm_param, 1);
+#else
+              av1_xform(x, AOM_PLANE_U, block, blk_row, blk_col, plane_bsize,
+                        &txfm_param);
+              av1_xform(x, AOM_PLANE_V, block, blk_row, blk_col, plane_bsize,
+                        &txfm_param);
+#endif
+            } else {
+              av1_xform_dc_only(x, AOM_PLANE_U, block, &txfm_param,
+                                per_px_mean);
+              av1_xform_dc_only(x, AOM_PLANE_V, block, &txfm_param,
+                                per_px_mean);
+            }
+            forward_cross_chroma_transform(x, block, txfm_param.tx_size);
+            break;
+          case AOM_PLANE_V: break;
+        }
+      } else {
+#endif
+        if (!dc_only_blk)
+#if CONFIG_IST
+          av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+                    1);
 #else
       av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param);
 #endif
-      else
-        av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
+        else
+          av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
+#if CONFIG_CROSS_CHROMA_TX
+      }
+#endif  // CONFIG_CROSS_CHROMA_TX
 
 #if CONFIG_IST
       skip_trellis_based_on_satd[txfm_param.tx_type] =
diff --git a/build/cmake/aom_config_defaults.cmake b/build/cmake/aom_config_defaults.cmake
index 46c4535..14e2a17 100644
--- a/build/cmake/aom_config_defaults.cmake
+++ b/build/cmake/aom_config_defaults.cmake
@@ -208,6 +208,8 @@
 set_aom_config_var(CONFIG_DST_32X32 0 NUMBER "AV2 DST7 32x32 experiment flag.")
 set_aom_config_var(CONFIG_DDT_INTER 0 NUMBER
                    "AV2 data-driven inter transform experiment flag.")
+set_aom_config_var(CONFIG_CROSS_CHROMA_TX 0 NUMBER
+                   "AV2 cross chroma component transform experiment flag.")
 #
 # Variables in this section control optional features of the build system.
 #