[tcq] Simplify get_rate*_c functions.

Split av1_get_rate_dist_lf into separate luma and chroma versions.
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 8e613bc..ccb71ae 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -330,8 +330,10 @@
     specialize qw/av1_get_rate_dist_def_luma avx2/;
     add_proto qw/void av1_get_rate_dist_def_chroma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int plane, int t_sign, int sign, int32_t rate_zero[TOTALSTATES], int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]";
     specialize qw/av1_get_rate_dist_def_chroma avx2/;
-    add_proto qw/void av1_get_rate_dist_lf/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int dc_sign_ctx, int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int plane, int blk_pos, int coeff_sign, int32_t rate_zero[TOTALSTATES], int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]";
-    specialize qw/av1_get_rate_dist_lf avx2/;
+    add_proto qw/void av1_get_rate_dist_lf_luma/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int dc_sign_ctx, int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int blk_pos, int coeff_sign, int32_t rate_zero[TOTALSTATES], int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]";
+    specialize qw/av1_get_rate_dist_lf_luma avx2/;
+    add_proto qw/void av1_get_rate_dist_lf_chroma/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int dc_sign_ctx, int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int plane, int blk_pos, int coeff_sign, int32_t rate_zero[TOTALSTATES], int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]";
+    specialize qw/av1_get_rate_dist_lf_chroma avx2/;
     add_proto qw/void av1_update_states/, "struct tcq_node_t *decision, int scan_idx, struct tcq_ctx_t *tcq_ctx";
     specialize qw/av1_update_states avx2/;
     add_proto qw/void av1_init_lf_ctx/, "const uint8_t *lev, int scan_hi, int bwl, struct tcq_lf_ctx_t *lf_ctx";
diff --git a/av1/encoder/trellis_quant.c b/av1/encoder/trellis_quant.c
index 29ece4b..f88d561 100644
--- a/av1/encoder/trellis_quant.c
+++ b/av1/encoder/trellis_quant.c
@@ -1018,61 +1018,22 @@
   const int sign = 0;
   const tran_low_t *absLevel = pq->absLevel;
   const int64_t *deltaDist = pq->deltaDist;
+
   for (int i = 0; i < TOTALSTATES; i++) {
-    int base_ctx = diag_ctx + (coeff_ctx[i] & 15);
     int dq = tcq_quant(i);
+    int a0 = dq;
+    int a1 = a0 + 2;
+    int base_ctx = diag_ctx + (coeff_ctx[i] & 15);
+    int cost0 = get_coeff_cost_def(absLevel[a0], coeff_ctx[i], diag_ctx, plane,
+                               txb_costs, dq, t_sign, sign);
+    int cost1 = get_coeff_cost_def(absLevel[a1], coeff_ctx[i], diag_ctx, plane,
+                               txb_costs, dq, t_sign, sign);
     rate_zero[i] = txb_costs->base_cost[base_ctx][dq][0];
+    rate[2 * i] = cost0;
+    rate[2 * i + 1] = cost1;
+    dist[2 * i] = deltaDist[a0];
+    dist[2 * i + 1] = deltaDist[a1];
   }
-  rate[0] = get_coeff_cost_def(absLevel[0], coeff_ctx[0], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[1] = get_coeff_cost_def(absLevel[2], coeff_ctx[0], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[2] = get_coeff_cost_def(absLevel[0], coeff_ctx[1], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[3] = get_coeff_cost_def(absLevel[2], coeff_ctx[1], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[4] = get_coeff_cost_def(absLevel[1], coeff_ctx[2], diag_ctx, plane,
-                               txb_costs, 1, t_sign, sign);
-  rate[5] = get_coeff_cost_def(absLevel[3], coeff_ctx[2], diag_ctx, plane,
-                               txb_costs, 1, t_sign, sign);
-  rate[6] = get_coeff_cost_def(absLevel[1], coeff_ctx[3], diag_ctx, plane,
-                               txb_costs, 1, t_sign, sign);
-  rate[7] = get_coeff_cost_def(absLevel[3], coeff_ctx[3], diag_ctx, plane,
-                               txb_costs, 1, t_sign, sign);
-  dist[0] = deltaDist[0];
-  dist[1] = deltaDist[2];
-  dist[2] = deltaDist[0];
-  dist[3] = deltaDist[2];
-  dist[4] = deltaDist[1];
-  dist[5] = deltaDist[3];
-  dist[6] = deltaDist[1];
-  dist[7] = deltaDist[3];
-#if MORESTATES
-  rate[8] = get_coeff_cost_def(absLevel[0], coeff_ctx[4], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[9] = get_coeff_cost_def(absLevel[2], coeff_ctx[4], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[10] = get_coeff_cost_def(absLevel[0], coeff_ctx[5], diag_ctx, plane,
-                                txb_costs, 0, t_sign, sign);
-  rate[11] = get_coeff_cost_def(absLevel[2], coeff_ctx[5], diag_ctx, plane,
-                                txb_costs, 0, t_sign, sign);
-  rate[12] = get_coeff_cost_def(absLevel[1], coeff_ctx[6], diag_ctx, plane,
-                                txb_costs, 1, t_sign, sign);
-  rate[13] = get_coeff_cost_def(absLevel[3], coeff_ctx[6], diag_ctx, plane,
-                                txb_costs, 1, t_sign, sign);
-  rate[14] = get_coeff_cost_def(absLevel[1], coeff_ctx[7], diag_ctx, plane,
-                                txb_costs, 1, t_sign, sign);
-  rate[15] = get_coeff_cost_def(absLevel[3], coeff_ctx[7], diag_ctx, plane,
-                                txb_costs, 1, t_sign, sign);
-  dist[8] = deltaDist[0];
-  dist[9] = deltaDist[2];
-  dist[10] = deltaDist[0];
-  dist[11] = deltaDist[2];
-  dist[12] = deltaDist[1];
-  dist[13] = deltaDist[3];
-  dist[14] = deltaDist[1];
-  dist[15] = deltaDist[3];
-#endif
 }
 
 void av1_get_rate_dist_def_chroma_c(const struct LV_MAP_COEFF_COST *txb_costs,
@@ -1084,230 +1045,89 @@
                                     int64_t dist[2 * TOTALSTATES]) {
   const tran_low_t *absLevel = pq->absLevel;
   const int64_t *deltaDist = pq->deltaDist;
+
   for (int i = 0; i < TOTALSTATES; i++) {
-    int base_ctx = diag_ctx + (coeff_ctx[i] & 15);
     int dq = tcq_quant(i);
+    int a0 = dq;
+    int a1 = a0 + 2;
+    int base_ctx = diag_ctx + (coeff_ctx[i] & 15);
+    int cost0 = get_coeff_cost_def(absLevel[a0], coeff_ctx[i], diag_ctx, plane,
+                               txb_costs, dq, t_sign, sign);
+    int cost1 = get_coeff_cost_def(absLevel[a1], coeff_ctx[i], diag_ctx, plane,
+                               txb_costs, dq, t_sign, sign);
     rate_zero[i] = txb_costs->base_cost_uv[base_ctx][dq][0];
+    rate[2 * i] = cost0;
+    rate[2 * i + 1] = cost1;
+    dist[2 * i] = deltaDist[a0];
+    dist[2 * i + 1] = deltaDist[a1];
   }
-  rate[0] = get_coeff_cost_def(absLevel[0], coeff_ctx[0], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[1] = get_coeff_cost_def(absLevel[2], coeff_ctx[0], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[2] = get_coeff_cost_def(absLevel[0], coeff_ctx[1], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[3] = get_coeff_cost_def(absLevel[2], coeff_ctx[1], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[4] = get_coeff_cost_def(absLevel[1], coeff_ctx[2], diag_ctx, plane,
-                               txb_costs, 1, t_sign, sign);
-  rate[5] = get_coeff_cost_def(absLevel[3], coeff_ctx[2], diag_ctx, plane,
-                               txb_costs, 1, t_sign, sign);
-  rate[6] = get_coeff_cost_def(absLevel[1], coeff_ctx[3], diag_ctx, plane,
-                               txb_costs, 1, t_sign, sign);
-  rate[7] = get_coeff_cost_def(absLevel[3], coeff_ctx[3], diag_ctx, plane,
-                               txb_costs, 1, t_sign, sign);
-  dist[0] = deltaDist[0];
-  dist[1] = deltaDist[2];
-  dist[2] = deltaDist[0];
-  dist[3] = deltaDist[2];
-  dist[4] = deltaDist[1];
-  dist[5] = deltaDist[3];
-  dist[6] = deltaDist[1];
-  dist[7] = deltaDist[3];
-#if MORESTATES
-  rate[8] = get_coeff_cost_def(absLevel[0], coeff_ctx[4], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[9] = get_coeff_cost_def(absLevel[2], coeff_ctx[4], diag_ctx, plane,
-                               txb_costs, 0, t_sign, sign);
-  rate[10] = get_coeff_cost_def(absLevel[0], coeff_ctx[5], diag_ctx, plane,
-                                txb_costs, 0, t_sign, sign);
-  rate[11] = get_coeff_cost_def(absLevel[2], coeff_ctx[5], diag_ctx, plane,
-                                txb_costs, 0, t_sign, sign);
-  rate[12] = get_coeff_cost_def(absLevel[1], coeff_ctx[6], diag_ctx, plane,
-                                txb_costs, 1, t_sign, sign);
-  rate[13] = get_coeff_cost_def(absLevel[3], coeff_ctx[6], diag_ctx, plane,
-                                txb_costs, 1, t_sign, sign);
-  rate[14] = get_coeff_cost_def(absLevel[1], coeff_ctx[7], diag_ctx, plane,
-                                txb_costs, 1, t_sign, sign);
-  rate[15] = get_coeff_cost_def(absLevel[3], coeff_ctx[7], diag_ctx, plane,
-                                txb_costs, 1, t_sign, sign);
-  dist[8] = deltaDist[0];
-  dist[9] = deltaDist[2];
-  dist[10] = deltaDist[0];
-  dist[11] = deltaDist[2];
-  dist[12] = deltaDist[1];
-  dist[13] = deltaDist[3];
-  dist[14] = deltaDist[1];
-  dist[15] = deltaDist[3];
-#endif
-#if 0
-  static int n = 0;
-  int s_rate_zero[TOTALSTATES];
-  int s_rate[2 * TOTALSTATES];
-  int64_t s_dist[2 * TOTALSTATES];
-  av1_get_rate_dist_def_chroma_avx2(txb_costs, pq, coeff_ctx, diag_ctx, plane, t_sign, sign,
-                                    s_rate_zero, s_rate, s_dist);
-  int ok = 1;
-  for (int i = 0; i < TOTALSTATES; i++) {
-    int chk = s_rate_zero[i] == rate_zero[i];
-    if (!chk) {
-      printf("CHK i %d rate_zero %d %d\n", i, s_rate_zero[i], rate_zero[i]);
-    }
-    ok &= chk;
-  }
-  for (int i = 0; i < 2*TOTALSTATES; i++) {
-    int chk = s_rate[i] == rate[i];
-    if (!chk) {
-      printf("CHK i %d rate %d %d\n", i, s_rate[i], rate[i]);
-    }
-    ok &= chk;
-  }
-  for (int i = 0; i < 2*TOTALSTATES; i++) {
-    int chk = s_dist[i] == dist[i];
-    if (!chk) {
-      printf("CHK i %d dist %ld %ld\n", i, s_dist[i], dist[i]);
-    }
-    ok &= chk;
-  }
-  n++;
-  if (!ok) {
-    printf("plane %d t_sign %d sign %d\n", plane, t_sign, sign);
-    exit(1);
-  }
-#endif
 }
 
-void av1_get_rate_dist_lf_c(const struct LV_MAP_COEFF_COST *txb_costs,
-                            const struct prequant_t *pq,
-                            const uint8_t coeff_ctx[TOTALSTATES + 4],
-                            int diag_ctx, int dc_sign_ctx, int32_t *tmp_sign,
-                            int bwl, TX_CLASS tx_class, int plane, int blk_pos,
-                            int coeff_sign, int32_t rate_zero[TOTALSTATES],
-                            int32_t rate[2 * TOTALSTATES],
-                            int64_t dist[2 * TOTALSTATES]) {
+void av1_get_rate_dist_lf_luma_c(const struct LV_MAP_COEFF_COST *txb_costs,
+                                 const struct prequant_t *pq,
+                                 const uint8_t coeff_ctx[TOTALSTATES + 4],
+                                 int diag_ctx, int dc_sign_ctx, int32_t *tmp_sign,
+                                 int bwl, TX_CLASS tx_class, int blk_pos,
+                                 int coeff_sign, int32_t rate_zero[TOTALSTATES],
+                                 int32_t rate[2 * TOTALSTATES],
+                                 int64_t dist[2 * TOTALSTATES]) {
+  const tran_low_t *absLevel = pq->absLevel;
+  const int64_t *deltaDist = pq->deltaDist;
+  uint8_t base_ctx[TOTALSTATES];
+  uint8_t mid_ctx[TOTALSTATES];
+  int plane = 0;  // Luma variant takes no plane argument; plane is fixed to 0.
+
+  for (int i = 0; i < TOTALSTATES; i++) {  // Per state: rate_zero[i], 2 rates, 2 dists.
+    int dq = tcq_quant(i);
+    int a0 = dq;      // Index of the first candidate in absLevel[] / deltaDist[].
+    int a1 = a0 + 2;  // Index of the second candidate.
+    base_ctx[i] = (coeff_ctx[i] & 15) + diag_ctx;  // Low 4 bits of coeff_ctx plus diag_ctx.
+    mid_ctx[i] = coeff_ctx[i] >> 4;  // High 4 bits of coeff_ctx.
+    int cost0 = get_coeff_cost(blk_pos, absLevel[a0], coeff_sign, base_ctx[i],
+                               mid_ctx[i], dc_sign_ctx, txb_costs, bwl, tx_class,
+                               tmp_sign, plane, 1, dq);
+    int cost1 = get_coeff_cost(blk_pos, absLevel[a1], coeff_sign, base_ctx[i],
+                               mid_ctx[i], dc_sign_ctx, txb_costs, bwl, tx_class,
+                               tmp_sign, plane, 1, dq);
+    rate_zero[i] = txb_costs->base_lf_cost[base_ctx[i]][dq][0];  // Luma table (not the _uv one).
+    rate[2 * i] = cost0;
+    rate[2 * i + 1] = cost1;
+    dist[2 * i] = deltaDist[a0];
+    dist[2 * i + 1] = deltaDist[a1];
+  }
+}
+
+void av1_get_rate_dist_lf_chroma_c(const struct LV_MAP_COEFF_COST *txb_costs,
+                                   const struct prequant_t *pq,
+                                   const uint8_t coeff_ctx[TOTALSTATES + 4],
+                                   int diag_ctx, int dc_sign_ctx, int32_t *tmp_sign,
+                                   int bwl, TX_CLASS tx_class, int plane, int blk_pos,
+                                   int coeff_sign, int32_t rate_zero[TOTALSTATES],
+                                   int32_t rate[2 * TOTALSTATES],
+                                   int64_t dist[2 * TOTALSTATES]) {
   const tran_low_t *absLevel = pq->absLevel;
   const int64_t *deltaDist = pq->deltaDist;
   uint8_t base_ctx[TOTALSTATES];
   uint8_t mid_ctx[TOTALSTATES];
 
   for (int i = 0; i < TOTALSTATES; i++) {
+    int dq = tcq_quant(i);
+    int a0 = dq;
+    int a1 = a0 + 2;
     base_ctx[i] = (coeff_ctx[i] & 15) + diag_ctx;
     mid_ctx[i] = coeff_ctx[i] >> 4;
+    int cost0 = get_coeff_cost(blk_pos, absLevel[a0], coeff_sign, base_ctx[i],
+                               mid_ctx[i], dc_sign_ctx, txb_costs, bwl, tx_class,
+                               tmp_sign, plane, 1, dq);
+    int cost1 = get_coeff_cost(blk_pos, absLevel[a1], coeff_sign, base_ctx[i],
+                               mid_ctx[i], dc_sign_ctx, txb_costs, bwl, tx_class,
+                               tmp_sign, plane, 1, dq);
+    rate_zero[i] = txb_costs->base_lf_cost_uv[base_ctx[i]][dq][0];
+    rate[2 * i] = cost0;
+    rate[2 * i + 1] = cost1;
+    dist[2 * i] = deltaDist[a0];
+    dist[2 * i + 1] = deltaDist[a1];
   }
-
-  // calculate RDcost
-  for (int i = 0; i < TOTALSTATES; i++) {
-    int dq = tcq_quant(i);
-    const int(*base_lf_cost_ptr)[DQ_CTXS][LF_BASE_SYMBOLS * 2] =
-        plane > 0 ? txb_costs->base_lf_cost_uv : txb_costs->base_lf_cost;
-    rate_zero[i] = base_lf_cost_ptr[base_ctx[i]][dq][0];
-  }
-
-  rate[0] = get_coeff_cost(blk_pos, absLevel[0], coeff_sign, base_ctx[0],
-                           mid_ctx[0], dc_sign_ctx, txb_costs, bwl, tx_class,
-                           tmp_sign, plane, 1, 0);
-  rate[1] = get_coeff_cost(blk_pos, absLevel[2], coeff_sign, base_ctx[0],
-                           mid_ctx[0], dc_sign_ctx, txb_costs, bwl, tx_class,
-                           tmp_sign, plane, 1, 0);
-  rate[2] = get_coeff_cost(blk_pos, absLevel[0], coeff_sign, base_ctx[1],
-                           mid_ctx[1], dc_sign_ctx, txb_costs, bwl, tx_class,
-                           tmp_sign, plane, 1, 0);
-  rate[3] = get_coeff_cost(blk_pos, absLevel[2], coeff_sign, base_ctx[1],
-                           mid_ctx[1], dc_sign_ctx, txb_costs, bwl, tx_class,
-                           tmp_sign, plane, 1, 0);
-  rate[4] = get_coeff_cost(blk_pos, absLevel[1], coeff_sign, base_ctx[2],
-                           mid_ctx[2], dc_sign_ctx, txb_costs, bwl, tx_class,
-                           tmp_sign, plane, 1, 1);
-  rate[5] = get_coeff_cost(blk_pos, absLevel[3], coeff_sign, base_ctx[2],
-                           mid_ctx[2], dc_sign_ctx, txb_costs, bwl, tx_class,
-                           tmp_sign, plane, 1, 1);
-  rate[6] = get_coeff_cost(blk_pos, absLevel[1], coeff_sign, base_ctx[3],
-                           mid_ctx[3], dc_sign_ctx, txb_costs, bwl, tx_class,
-                           tmp_sign, plane, 1, 1);
-  rate[7] = get_coeff_cost(blk_pos, absLevel[3], coeff_sign, base_ctx[3],
-                           mid_ctx[3], dc_sign_ctx, txb_costs, bwl, tx_class,
-                           tmp_sign, plane, 1, 1);
-#if MORESTATES
-  rate[8] = get_coeff_cost(blk_pos, absLevel[0], coeff_sign, base_ctx[4],
-                           mid_ctx[4], dc_sign_ctx, txb_costs, bwl, tx_class,
-                           tmp_sign, plane, 1, 0);
-  rate[9] = get_coeff_cost(blk_pos, absLevel[2], coeff_sign, base_ctx[4],
-                           mid_ctx[4], dc_sign_ctx, txb_costs, bwl, tx_class,
-                           tmp_sign, plane, 1, 0);
-  rate[10] = get_coeff_cost(blk_pos, absLevel[0], coeff_sign, base_ctx[5],
-                            mid_ctx[5], dc_sign_ctx, txb_costs, bwl, tx_class,
-                            tmp_sign, plane, 1, 0);
-  rate[11] = get_coeff_cost(blk_pos, absLevel[2], coeff_sign, base_ctx[5],
-                            mid_ctx[5], dc_sign_ctx, txb_costs, bwl, tx_class,
-                            tmp_sign, plane, 1, 0);
-  rate[12] = get_coeff_cost(blk_pos, absLevel[1], coeff_sign, base_ctx[6],
-                            mid_ctx[6], dc_sign_ctx, txb_costs, bwl, tx_class,
-                            tmp_sign, plane, 1, 1);
-  rate[13] = get_coeff_cost(blk_pos, absLevel[3], coeff_sign, base_ctx[6],
-                            mid_ctx[6], dc_sign_ctx, txb_costs, bwl, tx_class,
-                            tmp_sign, plane, 1, 1);
-  rate[14] = get_coeff_cost(blk_pos, absLevel[1], coeff_sign, base_ctx[7],
-                            mid_ctx[7], dc_sign_ctx, txb_costs, bwl, tx_class,
-                            tmp_sign, plane, 1, 1);
-  rate[15] = get_coeff_cost(blk_pos, absLevel[3], coeff_sign, base_ctx[7],
-                            mid_ctx[7], dc_sign_ctx, txb_costs, bwl, tx_class,
-                            tmp_sign, plane, 1, 1);
-#endif
-  dist[0] = deltaDist[0];
-  dist[1] = deltaDist[2];
-  dist[2] = deltaDist[0];
-  dist[3] = deltaDist[2];
-  dist[4] = deltaDist[1];
-  dist[5] = deltaDist[3];
-  dist[6] = deltaDist[1];
-  dist[7] = deltaDist[3];
-#if MORESTATES
-  dist[8] = deltaDist[0];
-  dist[9] = deltaDist[2];
-  dist[10] = deltaDist[0];
-  dist[11] = deltaDist[2];
-  dist[12] = deltaDist[1];
-  dist[13] = deltaDist[3];
-  dist[14] = deltaDist[1];
-  dist[15] = deltaDist[3];
-#endif
-#if 0
-  static int n = 0;
-  int t_sign = tmp_sign[blk_pos];
-  int sign = coeff_sign;
-  int s_rate_zero[TOTALSTATES];
-  int s_rate[2 * TOTALSTATES];
-  int64_t s_dist[2 * TOTALSTATES];
-  av1_get_rate_dist_lf_avx2(txb_costs, pq, coeff_ctx, diag_ctx, dc_sign_ctx,
-                            tmp_sign, bwl, tx_class, plane, blk_pos, coeff_sign,
-                            s_rate_zero, s_rate, s_dist);
-  int ok = 1;
-  for (int i = 0; i < TOTALSTATES; i++) {
-    int chk = s_rate_zero[i] == rate_zero[i];
-    if (!chk) {
-      printf("CHK i %d rate_zero %d %d\n", i, s_rate_zero[i], rate_zero[i]);
-    }
-    ok &= chk;
-  }
-  for (int i = 0; i < 2*TOTALSTATES; i++) {
-    int chk = s_rate[i] == rate[i];
-    if (!chk || n == 0) {
-      printf("CHK i %d rate %d %d\n", i, s_rate[i], rate[i]);
-    }
-    ok &= chk;
-  }
-  for (int i = 0; i < 2*TOTALSTATES; i++) {
-    int chk = s_dist[i] == dist[i];
-    if (!chk) {
-      printf("CHK i %d dist %ld %ld\n", i, s_dist[i], dist[i]);
-    }
-    ok &= chk;
-  }
-  n++;
-  if (!ok) {
-    printf("plane %d t_sign %d sign %d\n", plane, t_sign, sign);
-    exit(1);
-  }
-#endif
 }
 
 void av1_calc_diag_ctx_c(int scan_hi, int scan_lo, int bwl,
@@ -1521,7 +1341,8 @@
 }
 
 void av1_update_lf_ctx_c(const struct tcq_node_t *decision,
-                         struct tcq_lf_ctx_t *lf_ctx) {
+                         struct tcq_lf_ctx_t *lf_ctx)
+{
   tcq_lf_ctx_t save[TOTALSTATES];
   memcpy(save, lf_ctx, sizeof(tcq_lf_ctx_t) * TOTALSTATES);
 
@@ -1589,9 +1410,9 @@
     int rate_zero[TOTALSTATES];
     int rate[2 * TOTALSTATES];
     int64_t dist[2 * TOTALSTATES];
-    av1_get_rate_dist_lf(txb_costs, &pqData, coeff_ctx, diag_ctx,
-                         txb_ctx->dc_sign_ctx, tmp_sign, bwl, tx_class, plane,
-                         blk_pos, coeff_sign, rate_zero, rate, dist);
+    av1_get_rate_dist_lf_luma(txb_costs, &pqData, coeff_ctx, diag_ctx,
+                              txb_ctx->dc_sign_ctx, tmp_sign, bwl, tx_class,
+                              blk_pos, coeff_sign, rate_zero, rate, dist);
 
     av1_decide_states(prd, dist, rate, rate_zero, &pqData, limits, rdmult,
                       decision);
@@ -1716,10 +1537,14 @@
     int32_t rate_zero[TOTALSTATES];
     int32_t rate[2 * TOTALSTATES];
     int64_t dist[2 * TOTALSTATES];
-    if (limits) {
-      av1_get_rate_dist_lf(txb_costs, &pqData, coeff_ctx, diag_ctx,
-                           txb_ctx->dc_sign_ctx, tmp_sign, bwl, tx_class, plane,
-                           blk_pos, coeff_sign, rate_zero, rate, dist);
+    if (limits && plane == 0) {
+      av1_get_rate_dist_lf_luma(txb_costs, &pqData, coeff_ctx, diag_ctx,
+                                txb_ctx->dc_sign_ctx, tmp_sign, bwl, tx_class,
+                                blk_pos, coeff_sign, rate_zero, rate, dist);
+    } else if (limits) {
+      av1_get_rate_dist_lf_chroma(txb_costs, &pqData, coeff_ctx, diag_ctx,
+                                  txb_ctx->dc_sign_ctx, tmp_sign, bwl, tx_class, plane,
+                                  blk_pos, coeff_sign, rate_zero, rate, dist);
     } else if (plane == 0) {
       av1_get_rate_dist_def_luma(txb_costs, &pqData, coeff_ctx, diag_ctx,
                                  rate_zero, rate, dist);
diff --git a/av1/encoder/x86/trellis_quant_avx2.c b/av1/encoder/x86/trellis_quant_avx2.c
index 64b35b5..209efa9 100644
--- a/av1/encoder/x86/trellis_quant_avx2.c
+++ b/av1/encoder/x86/trellis_quant_avx2.c
@@ -550,10 +550,8 @@
                             struct tcq_lf_ctx_t *lf_ctx) {
   __m256i upd_last_a;
   __m256i upd_last_b;
-#if MORESTATES
   __m256i upd_last_c;
   __m256i upd_last_d;
-#endif
   for (int st = 0; st < TOTALSTATES; st += 2) {
     int absLevel0 = decision[st].absLevel;
     int prevId0 = decision[st].prevId;
@@ -573,10 +571,8 @@
     upd1 = _mm_insert_epi8(upd1, AOMMIN(absLevel1, INT8_MAX), 0);
     __m256i upd01 = _mm256_castsi128_si256(upd0);
     upd01 = _mm256_inserti128_si256(upd01, upd1, 1);
-#if MORESTATES
     upd_last_d = upd_last_c;
     upd_last_c = upd_last_b;
-#endif
     upd_last_b = upd_last_a;
     upd_last_a = upd01;
   }
@@ -586,12 +582,126 @@
   _mm256_storeu_si256((__m256i *)lf_ctx[4].last, upd_last_b);
   _mm256_storeu_si256((__m256i *)lf_ctx[6].last, upd_last_a);
 #else
+  (void)upd_last_d;
+  (void)upd_last_c;
   _mm256_storeu_si256((__m256i *)lf_ctx[0].last, upd_last_b);
   _mm256_storeu_si256((__m256i *)lf_ctx[2].last, upd_last_a);
 #endif
 }
 
-void av1_get_rate_dist_lf_avx2(
+void av1_get_rate_dist_lf_luma_avx2(
+    const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
+    const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int dc_sign_ctx,
+    int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int blk_pos,
+    int coeff_sign, int32_t rate_zero[TOTALSTATES],
+    int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]) {
+#define Z -1  // Shuffle index with the high bit set: that output byte is zeroed.
+  static const int8_t kShuf[2][32] = {
+    { 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15,
+      0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 },
+    { 0, 8,  Z, Z, 1, 9,  Z, Z, 2, 10, Z, Z, 3, 11, Z, Z,
+      4, 12, Z, Z, 5, 13, Z, Z, 6, 14, Z, Z, 7, 15, Z, Z }
+  };
+  const uint16_t(*cost_low)[LF_BASE_SYMBOLS][LF_SIG_COEF_CONTEXTS] =
+      txb_costs->base_lf_cost_low;
+  const uint16_t(*cost_low_tbl)[LF_SIG_COEF_CONTEXTS][DQ_CTXS][2] =
+      txb_costs->base_lf_cost_low_tbl;
+  const tran_low_t *absLevel = pq->absLevel;
+  const int64_t *deltaDist = pq->deltaDist;
+  const int plane = 0;  // Luma variant takes no plane argument; plane is fixed to 0.
+
+  // Copy distortion stats.
+  __m256i delta_dist = _mm256_lddqu_si256((__m256i *)deltaDist);
+  __m256i dist02 = _mm256_permute4x64_epi64(delta_dist, 0x88);  // 64-bit elements 0,2,0,2.
+  __m256i dist13 = _mm256_permute4x64_epi64(delta_dist, 0xDD);  // 64-bit elements 1,3,1,3.
+  _mm256_storeu_si256((__m256i *)&dist[0], dist02);
+  _mm256_storeu_si256((__m256i *)&dist[4], dist13);
+#if MORESTATES
+  _mm256_storeu_si256((__m256i *)&dist[8], dist02);
+  _mm256_storeu_si256((__m256i *)&dist[12], dist13);
+#endif
+
+  // Calc zero coeff costs.
+  __m256i cost_zero_dq0 =
+      _mm256_lddqu_si256((__m256i *)&cost_low[0][0][diag_ctx]);
+  __m256i cost_zero_dq1 =
+      _mm256_lddqu_si256((__m256i *)&cost_low[1][0][diag_ctx]);
+  __m256i shuf = _mm256_lddqu_si256((__m256i *)kShuf[0]);
+  cost_zero_dq0 = _mm256_shuffle_epi8(cost_zero_dq0, shuf);
+  cost_zero_dq1 = _mm256_shuffle_epi8(cost_zero_dq1, shuf);
+  __m256i cost_dq0 = _mm256_permute4x64_epi64(cost_zero_dq0, 0xD8);
+  __m256i cost_dq1 = _mm256_permute4x64_epi64(cost_zero_dq1, 0xD8);
+  __m256i ctx = _mm256_castsi128_si256(_mm_loadu_si64(coeff_ctx));
+  __m256i fifteen = _mm256_set1_epi8(15);
+  __m256i base_ctx = _mm256_and_si256(ctx, fifteen);  // Keep the low 4 bits of each coeff_ctx byte.
+  base_ctx = _mm256_permute4x64_epi64(base_ctx, 0);
+  __m256i ratez_dq0 = _mm256_shuffle_epi8(cost_dq0, base_ctx);
+  __m256i ratez_dq1 = _mm256_shuffle_epi8(cost_dq1, base_ctx);
+  __m256i ratez = _mm256_blend_epi16(ratez_dq0, ratez_dq1, 0xAA);
+  ratez = _mm256_permute4x64_epi64(ratez, 0x88);
+  __m256i shuf1 = _mm256_lddqu_si256((__m256i *)kShuf[1]);
+  ratez = _mm256_shuffle_epi8(ratez, shuf1);
+#if MORESTATES
+  _mm256_storeu_si256((__m256i *)&rate_zero[0], ratez);
+#else
+  _mm_storeu_si128((__m128i *)&rate_zero[0], _mm256_castsi256_si128(ratez));
+#endif
+
+  // Calc coeff_base rate.
+  int idx = AOMMIN(pq->qIdx - 1, 8);  // First index into cost_low_tbl, capped at 8.
+  for (int i = 0; i < TOTALSTATES / 4; i++) {  // Four states (eight rate entries) per pass.
+    int j = 4 * i;
+    int ctx0 = diag_ctx + (coeff_ctx[j + 0] & 15);
+    int ctx1 = diag_ctx + (coeff_ctx[j + 1] & 15);
+    int ctx2 = diag_ctx + (coeff_ctx[j + 2] & 15);
+    int ctx3 = diag_ctx + (coeff_ctx[j + 3] & 15);
+    __m128i rate_01 = _mm_loadu_si64(&cost_low_tbl[idx][ctx0][0]);
+    __m128i rate_23 = _mm_loadu_si64(&cost_low_tbl[idx][ctx1][0]);
+    __m128i rate_45 = _mm_loadu_si64(&cost_low_tbl[idx][ctx2][1]);
+    __m128i rate_67 = _mm_loadu_si64(&cost_low_tbl[idx][ctx3][1]);
+    __m128i rate_0123 = _mm_unpacklo_epi32(rate_01, rate_23);
+    __m128i rate_4567 = _mm_unpacklo_epi32(rate_45, rate_67);
+    __m128i c_zero = _mm_setzero_si128();
+    rate_0123 = _mm_unpacklo_epi16(rate_0123, c_zero);  // Zero-extend 16-bit costs to 32 bits.
+    rate_4567 = _mm_unpacklo_epi16(rate_4567, c_zero);
+    _mm_storeu_si128((__m128i *)&rate[8 * i], rate_0123);
+    _mm_storeu_si128((__m128i *)&rate[8 * i + 4], rate_4567);
+  }
+
+  const int row = blk_pos >> bwl;
+  const int col = blk_pos - (row << bwl);
+  const bool dc_2dtx = (blk_pos == 0);
+  const bool dc_hor = (col == 0) && tx_class == TX_CLASS_HORIZ;
+  const bool dc_ver = (row == 0) && tx_class == TX_CLASS_VERT;
+  const bool is_dc_coeff = dc_2dtx || dc_hor || dc_ver;
+  if (is_dc_coeff) {  // Add get_mid_cost_lf_dc() to both rate entries of every state.
+    for (int i = 0; i < TOTALSTATES; i++) {
+      int a0 = i & 2 ? 1 : 0;  // Candidate index: 0 when bit 1 of i is clear, else 1.
+      int a1 = a0 + 2;
+      int mid_cost0 =
+          get_mid_cost_lf_dc(blk_pos, absLevel[a0], coeff_sign, coeff_ctx[i],
+                             dc_sign_ctx, txb_costs, tmp_sign, plane);
+      int mid_cost1 =
+          get_mid_cost_lf_dc(blk_pos, absLevel[a1], coeff_sign, coeff_ctx[i],
+                             dc_sign_ctx, txb_costs, tmp_sign, plane);
+      rate[2 * i] += mid_cost0;
+      rate[2 * i + 1] += mid_cost1;
+    }
+  } else if (idx > 4) {  // Non-DC: get_mid_cost_lf() is added only when idx > 4.
+    for (int i = 0; i < TOTALSTATES; i++) {
+      int a0 = i & 2 ? 1 : 0;
+      int a1 = a0 + 2;
+      int mid_cost0 =
+          get_mid_cost_lf(absLevel[a0], coeff_ctx[i], txb_costs, plane);
+      int mid_cost1 =
+          get_mid_cost_lf(absLevel[a1], coeff_ctx[i], txb_costs, plane);
+      rate[2 * i] += mid_cost0;
+      rate[2 * i + 1] += mid_cost1;
+    }
+  }
+}
+
+void av1_get_rate_dist_lf_chroma_avx2(
     const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
     const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int dc_sign_ctx,
     int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int plane, int blk_pos,