[tcq] Vectorize av1_get_rate_dist_* functions.

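Splits av1_get_rate_dist_def into luma and chroma variants, adds
av1_get_rate_dist_lf, and registers AVX2 specializations for all three.
Reduces encoding time (EncTime) by a few percent.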
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 14540d4..0ba73cd 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -325,8 +325,12 @@
     specialize qw/av1_pre_quant avx2/;
     add_proto qw/void av1_calc_diag_ctx/, "int scan_hi, int scan_lo, int bwl, const uint8_t *prev_levels, const int16_t* scan, uint8_t *ctx";
     specialize qw/av1_calc_diag_ctx avx2/;
-    add_proto qw/void av1_get_rate_dist_def/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const uint8_t coeff_ctx[2 * TOTALSTATES], int diag_ctx, int plane, int32_t rate_zero[TOTALSTATES], int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]";
-    specialize qw/av1_get_rate_dist_def avx2/;
+    add_proto qw/void av1_get_rate_dist_def_luma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int32_t rate_zero[TOTALSTATES], int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]";
+    specialize qw/av1_get_rate_dist_def_luma avx2/;
+    add_proto qw/void av1_get_rate_dist_def_chroma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int plane, int t_sign, int sign, int32_t rate_zero[TOTALSTATES], int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]";
+    specialize qw/av1_get_rate_dist_def_chroma avx2/;
+    add_proto qw/void av1_get_rate_dist_lf/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int dc_sign_ctx, int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int plane, int blk_pos, int coeff_sign, int32_t rate_zero[TOTALSTATES], int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]";
+    specialize qw/av1_get_rate_dist_lf avx2/;
     add_proto qw/void av1_update_states/, "struct tcq_node_t *decision, int scan_idx, struct tcq_ctx_t *tcq_ctx";
     specialize qw/av1_update_states avx2/;
   }
diff --git a/av1/common/quant_common.h b/av1/common/quant_common.h
index a23a142..a46a1f9 100644
--- a/av1/common/quant_common.h
+++ b/av1/common/quant_common.h
@@ -67,7 +67,6 @@
 struct macroblockd;
 
 #if CONFIG_DQ
-int tcq_parity(int absLevel, int limits);
 bool tcq_quant(const int state);
 int tcq_parity(int absLevel, int limits);
 int tcq_next_state(const int curState, const int absLevel, const int limits);
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 446bce5..799bfdd 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -227,8 +227,14 @@
                [8];
 #if CONFIG_DQ
   //! Quick access to base costs 0-3 for optimized access.
-  int base_cost_low[DQ_CTXS][4][SIG_COEF_CONTEXTS];
+  int32_t base_cost_low[DQ_CTXS][4][SIG_COEF_CONTEXTS];
+  int32_t base_cost_uv_low[DQ_CTXS][4][SIG_COEF_CONTEXTS];
   uint16_t base_cost_low_tbl[5][SIG_COEF_CONTEXTS][DQ_CTXS][2];
+  uint16_t base_cost_uv_low_tbl[5][SIG_COEF_CONTEXTS][DQ_CTXS][2];
+  uint16_t base_lf_cost_low[DQ_CTXS][LF_BASE_SYMBOLS][LF_SIG_COEF_CONTEXTS];
+  uint16_t base_lf_cost_uv_low[DQ_CTXS][LF_BASE_SYMBOLS][LF_SIG_COEF_CONTEXTS];
+  uint16_t base_lf_cost_low_tbl[9][LF_SIG_COEF_CONTEXTS][DQ_CTXS][2];
+  uint16_t base_lf_cost_uv_low_tbl[9][LF_SIG_COEF_CONTEXTS][DQ_CTXS][2];
 #endif
 #if CONFIG_DQ && !CONFIG_LCCHROMA
   //! Cost for encoding the base level of a low-frequency coefficient
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 6363d09..192b398 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -1206,22 +1206,43 @@
           for (int ctx = 0; ctx < SIG_COEF_CONTEXTS; ++ctx) {
             pcost->base_cost_low[dq][lev][ctx] = pcost->base_cost[ctx][dq][lev];
           }
+          for (int ctx = 0; ctx < SIG_COEF_CONTEXTS_UV; ++ctx) {
+            pcost->base_cost_uv_low[dq][lev][ctx] =
+                pcost->base_cost_uv[ctx][dq][lev];
+          }
+        }
+      }
+      // Rearrange costs into base_lf_cost_low[] array for quicker access.
+      for (int dq = 0; dq < DQ_CTXS; dq++) {
+        for (int lev = 0; lev < LF_BASE_SYMBOLS; lev++) {
+          for (int ctx = 0; ctx < LF_SIG_COEF_CONTEXTS; ++ctx) {
+            pcost->base_lf_cost_low[dq][lev][ctx] =
+                pcost->base_lf_cost[ctx][dq][lev];
+          }
+          for (int ctx = 0; ctx < LF_SIG_COEF_CONTEXTS_UV; ++ctx) {
+            pcost->base_lf_cost_uv_low[dq][lev][ctx] =
+                pcost->base_lf_cost_uv[ctx][dq][lev];
+          }
         }
       }
       // Precompute some base_costs for trellis, interleaved for quick access.
-      static const uint8_t trel_abslev[5][4] = {
+      static const uint8_t trel_abslev[9][4] = {
         { 2, 1, 1, 2 },  // qIdx = 1
         { 2, 3, 1, 2 },  // qIdx = 2
         { 2, 3, 3, 2 },  // qIdx = 3
-        { 2, 3, 3, 3 },  // qIdx = 4
-        { 3, 3, 3, 3 },  // qIdx = 5+
+        { 2, 3, 3, 4 },  // qIdx = 4
+        { 4, 3, 3, 4 },  // qIdx = 5
+        { 4, 5, 3, 4 },  // qIdx = 6
+        { 4, 5, 5, 4 },  // qIdx = 7
+        { 4, 5, 5, 6 },  // qIdx = 8
+        { 6, 5, 5, 6 },  // qIdx = 9
       };
       for (int idx = 0; idx < 5; idx++) {
         for (int ctx = 0; ctx < SIG_COEF_CONTEXTS; ++ctx) {
-          int a0 = trel_abslev[idx][0];
-          int a1 = trel_abslev[idx][1];
-          int a2 = trel_abslev[idx][2];
-          int a3 = trel_abslev[idx][3];
+          int a0 = AOMMIN(trel_abslev[idx][0], 3);
+          int a1 = AOMMIN(trel_abslev[idx][1], 3);
+          int a2 = AOMMIN(trel_abslev[idx][2], 3);
+          int a3 = AOMMIN(trel_abslev[idx][3], 3);
           // DQ0, absLev 0 / 2
           pcost->base_cost_low_tbl[idx][ctx][0][0] =
               pcost->base_cost[ctx][0][a0] + av1_cost_literal(1);
@@ -1234,6 +1255,62 @@
               pcost->base_cost[ctx][1][a3] + av1_cost_literal(1);
         }
       }
+      for (int idx = 0; idx < 5; idx++) {
+        for (int ctx = 0; ctx < SIG_COEF_CONTEXTS_UV; ++ctx) {
+          int a0 = AOMMIN(trel_abslev[idx][0], 3);
+          int a1 = AOMMIN(trel_abslev[idx][1], 3);
+          int a2 = AOMMIN(trel_abslev[idx][2], 3);
+          int a3 = AOMMIN(trel_abslev[idx][3], 3);
+          // DQ0, uv, absLev 0 / 2
+          pcost->base_cost_uv_low_tbl[idx][ctx][0][0] =
+              pcost->base_cost_uv[ctx][0][a0] + av1_cost_literal(1);
+          pcost->base_cost_uv_low_tbl[idx][ctx][0][1] =
+              pcost->base_cost_uv[ctx][0][a2] + av1_cost_literal(1);
+          // DQ1, uv, absLev 1 / 3
+          pcost->base_cost_uv_low_tbl[idx][ctx][1][0] =
+              pcost->base_cost_uv[ctx][1][a1] + av1_cost_literal(1);
+          pcost->base_cost_uv_low_tbl[idx][ctx][1][1] =
+              pcost->base_cost_uv[ctx][1][a3] + av1_cost_literal(1);
+        }
+      }
+      for (int idx = 0; idx < 9; idx++) {
+        for (int ctx = 0; ctx < LF_SIG_COEF_CONTEXTS; ++ctx) {
+          int max = LF_BASE_SYMBOLS - 1;
+          int a0 = AOMMIN(trel_abslev[idx][0], max);
+          int a1 = AOMMIN(trel_abslev[idx][1], max);
+          int a2 = AOMMIN(trel_abslev[idx][2], max);
+          int a3 = AOMMIN(trel_abslev[idx][3], max);
+          // DQ0, absLev 0 / 2
+          pcost->base_lf_cost_low_tbl[idx][ctx][0][0] =
+              pcost->base_lf_cost[ctx][0][a0] + av1_cost_literal(1);
+          pcost->base_lf_cost_low_tbl[idx][ctx][0][1] =
+              pcost->base_lf_cost[ctx][0][a2] + av1_cost_literal(1);
+          // DQ1, absLev 1 / 3
+          pcost->base_lf_cost_low_tbl[idx][ctx][1][0] =
+              pcost->base_lf_cost[ctx][1][a1] + av1_cost_literal(1);
+          pcost->base_lf_cost_low_tbl[idx][ctx][1][1] =
+              pcost->base_lf_cost[ctx][1][a3] + av1_cost_literal(1);
+        }
+      }
+      for (int idx = 0; idx < 9; idx++) {
+        for (int ctx = 0; ctx < LF_SIG_COEF_CONTEXTS_UV; ++ctx) {
+          int max = LF_BASE_SYMBOLS - 1;
+          int a0 = AOMMIN(trel_abslev[idx][0], max);
+          int a1 = AOMMIN(trel_abslev[idx][1], max);
+          int a2 = AOMMIN(trel_abslev[idx][2], max);
+          int a3 = AOMMIN(trel_abslev[idx][3], max);
+          // DQ0, absLev 0 / 2
+          pcost->base_lf_cost_uv_low_tbl[idx][ctx][0][0] =
+              pcost->base_lf_cost_uv[ctx][0][a0] + av1_cost_literal(1);
+          pcost->base_lf_cost_uv_low_tbl[idx][ctx][0][1] =
+              pcost->base_lf_cost_uv[ctx][0][a2] + av1_cost_literal(1);
+          // DQ1, absLev 1 / 3
+          pcost->base_lf_cost_uv_low_tbl[idx][ctx][1][0] =
+              pcost->base_lf_cost_uv[ctx][1][a1] + av1_cost_literal(1);
+          pcost->base_lf_cost_uv_low_tbl[idx][ctx][1][1] =
+              pcost->base_lf_cost_uv[ctx][1][a3] + av1_cost_literal(1);
+        }
+      }
 #else
       for (int ctx = 0; ctx < LF_SIG_COEF_CONTEXTS_UV; ++ctx) {
         av1_cost_tokens_from_cdf(pcost->base_lf_cost_uv[ctx],
diff --git a/av1/encoder/trellis_quant.c b/av1/encoder/trellis_quant.c
index 2fffcbd..aed5c9c 100644
--- a/av1/encoder/trellis_quant.c
+++ b/av1/encoder/trellis_quant.c
@@ -421,15 +421,27 @@
 }
 
 #if CONFIG_CONTEXT_DERIVATION && CONFIG_LCCHROMA && CONFIG_IMPROVEIDTX_CTXS
-static int get_coeff_cost_def(tran_low_t abs_qc, int coeff_ctx, int diag_ctx,
-                              const LV_MAP_COEFF_COST *txb_costs, int dq) {
+static INLINE int get_coeff_cost_def(tran_low_t abs_qc, int coeff_ctx,
+                                     int diag_ctx, int plane,
+                                     const LV_MAP_COEFF_COST *txb_costs, int dq,
+                                     int t_sign, int sign) {
   int base_ctx = diag_ctx + (coeff_ctx & 15);
   int mid_ctx = coeff_ctx >> 4;
-  int cost = txb_costs->base_cost[base_ctx][dq][AOMMIN(abs_qc, 3)];
+  const int(*base_cost_ptr)[DQ_CTXS][8] =
+      plane > 0 ? txb_costs->base_cost_uv : txb_costs->base_cost;
+  int cost = base_cost_ptr[base_ctx][dq][AOMMIN(abs_qc, 3)];
   if (abs_qc != 0) {
-    cost += av1_cost_literal(1);
-    if (abs_qc > NUM_BASE_LEVELS)
-      cost += get_br_cost(abs_qc, txb_costs->lps_cost[mid_ctx]);
+    if (plane == AOM_PLANE_V)
+      cost += txb_costs->v_ac_sign_cost[t_sign][sign];
+    else
+      cost += av1_cost_literal(1);
+    if (abs_qc > NUM_BASE_LEVELS) {
+      if (plane == 0) {
+        cost += get_br_cost(abs_qc, txb_costs->lps_cost[mid_ctx]);
+      } else {
+        cost += get_br_cost(abs_qc, txb_costs->lps_cost_uv[mid_ctx]);
+      }
+    }
   }
   return cost;
 }
@@ -995,38 +1007,38 @@
                  scan, scan_pos, scan_pos, bwl, sharpness);
 }
 
-void av1_get_rate_dist_def_c(const struct LV_MAP_COEFF_COST *txb_costs,
-                             const struct prequant_t *pq,
-                             const uint8_t coeff_ctx[2 * TOTALSTATES],
-                             int diag_ctx, int plane,
-                             int32_t rate_zero[TOTALSTATES],
-                             int32_t rate[2 * TOTALSTATES],
-                             int64_t dist[2 * TOTALSTATES]) {
+void av1_get_rate_dist_def_luma_c(const struct LV_MAP_COEFF_COST *txb_costs,
+                                  const struct prequant_t *pq,
+                                  const uint8_t coeff_ctx[TOTALSTATES + 4],
+                                  int diag_ctx, int32_t rate_zero[TOTALSTATES],
+                                  int32_t rate[2 * TOTALSTATES],
+                                  int64_t dist[2 * TOTALSTATES]) {
+  const int plane = 0;
+  const int t_sign = 0;
+  const int sign = 0;
   const tran_low_t *absLevel = pq->absLevel;
   const int64_t *deltaDist = pq->deltaDist;
   for (int i = 0; i < TOTALSTATES; i++) {
     int base_ctx = diag_ctx + (coeff_ctx[i] & 15);
     int dq = tcq_quant(i);
-    const int(*base_cost_ptr)[DQ_CTXS][8] =
-        plane > 0 ? txb_costs->base_cost_uv : txb_costs->base_cost;
-    rate_zero[i] = base_cost_ptr[base_ctx][dq][0];
+    rate_zero[i] = txb_costs->base_cost[base_ctx][dq][0];
   }
-  rate[0] =
-      get_coeff_cost_def(absLevel[0], coeff_ctx[0], diag_ctx, txb_costs, 0);
-  rate[1] =
-      get_coeff_cost_def(absLevel[2], coeff_ctx[0], diag_ctx, txb_costs, 0);
-  rate[2] =
-      get_coeff_cost_def(absLevel[0], coeff_ctx[1], diag_ctx, txb_costs, 0);
-  rate[3] =
-      get_coeff_cost_def(absLevel[2], coeff_ctx[1], diag_ctx, txb_costs, 0);
-  rate[4] =
-      get_coeff_cost_def(absLevel[1], coeff_ctx[2], diag_ctx, txb_costs, 1);
-  rate[5] =
-      get_coeff_cost_def(absLevel[3], coeff_ctx[2], diag_ctx, txb_costs, 1);
-  rate[6] =
-      get_coeff_cost_def(absLevel[1], coeff_ctx[3], diag_ctx, txb_costs, 1);
-  rate[7] =
-      get_coeff_cost_def(absLevel[3], coeff_ctx[3], diag_ctx, txb_costs, 1);
+  rate[0] = get_coeff_cost_def(absLevel[0], coeff_ctx[0], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[1] = get_coeff_cost_def(absLevel[2], coeff_ctx[0], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[2] = get_coeff_cost_def(absLevel[0], coeff_ctx[1], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[3] = get_coeff_cost_def(absLevel[2], coeff_ctx[1], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[4] = get_coeff_cost_def(absLevel[1], coeff_ctx[2], diag_ctx, plane,
+                               txb_costs, 1, t_sign, sign);
+  rate[5] = get_coeff_cost_def(absLevel[3], coeff_ctx[2], diag_ctx, plane,
+                               txb_costs, 1, t_sign, sign);
+  rate[6] = get_coeff_cost_def(absLevel[1], coeff_ctx[3], diag_ctx, plane,
+                               txb_costs, 1, t_sign, sign);
+  rate[7] = get_coeff_cost_def(absLevel[3], coeff_ctx[3], diag_ctx, plane,
+                               txb_costs, 1, t_sign, sign);
   dist[0] = deltaDist[0];
   dist[1] = deltaDist[2];
   dist[2] = deltaDist[0];
@@ -1036,22 +1048,22 @@
   dist[6] = deltaDist[1];
   dist[7] = deltaDist[3];
 #if MORESTATES
-  rate[8] =
-      get_coeff_cost_def(absLevel[0], coeff_ctx[4], diag_ctx, txb_costs, 0);
-  rate[9] =
-      get_coeff_cost_def(absLevel[2], coeff_ctx[4], diag_ctx, txb_costs, 0);
-  rate[10] =
-      get_coeff_cost_def(absLevel[0], coeff_ctx[5], diag_ctx, txb_costs, 0);
-  rate[11] =
-      get_coeff_cost_def(absLevel[2], coeff_ctx[5], diag_ctx, txb_costs, 0);
-  rate[12] =
-      get_coeff_cost_def(absLevel[1], coeff_ctx[6], diag_ctx, txb_costs, 1);
-  rate[13] =
-      get_coeff_cost_def(absLevel[3], coeff_ctx[6], diag_ctx, txb_costs, 1);
-  rate[14] =
-      get_coeff_cost_def(absLevel[1], coeff_ctx[7], diag_ctx, txb_costs, 1);
-  rate[15] =
-      get_coeff_cost_def(absLevel[3], coeff_ctx[7], diag_ctx, txb_costs, 1);
+  rate[8] = get_coeff_cost_def(absLevel[0], coeff_ctx[4], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[9] = get_coeff_cost_def(absLevel[2], coeff_ctx[4], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[10] = get_coeff_cost_def(absLevel[0], coeff_ctx[5], diag_ctx, plane,
+                                txb_costs, 0, t_sign, sign);
+  rate[11] = get_coeff_cost_def(absLevel[2], coeff_ctx[5], diag_ctx, plane,
+                                txb_costs, 0, t_sign, sign);
+  rate[12] = get_coeff_cost_def(absLevel[1], coeff_ctx[6], diag_ctx, plane,
+                                txb_costs, 1, t_sign, sign);
+  rate[13] = get_coeff_cost_def(absLevel[3], coeff_ctx[6], diag_ctx, plane,
+                                txb_costs, 1, t_sign, sign);
+  rate[14] = get_coeff_cost_def(absLevel[1], coeff_ctx[7], diag_ctx, plane,
+                                txb_costs, 1, t_sign, sign);
+  rate[15] = get_coeff_cost_def(absLevel[3], coeff_ctx[7], diag_ctx, plane,
+                                txb_costs, 1, t_sign, sign);
   dist[8] = deltaDist[0];
   dist[9] = deltaDist[2];
   dist[10] = deltaDist[0];
@@ -1063,6 +1075,247 @@
 #endif
 }
 
+void av1_get_rate_dist_def_chroma_c(const struct LV_MAP_COEFF_COST *txb_costs,
+                                    const struct prequant_t *pq,
+                                    const uint8_t coeff_ctx[TOTALSTATES + 4],
+                                    int diag_ctx, int plane, int t_sign,
+                                    int sign, int32_t rate_zero[TOTALSTATES],
+                                    int32_t rate[2 * TOTALSTATES],
+                                    int64_t dist[2 * TOTALSTATES]) {
+  const tran_low_t *absLevel = pq->absLevel;
+  const int64_t *deltaDist = pq->deltaDist;
+  for (int i = 0; i < TOTALSTATES; i++) {
+    int base_ctx = diag_ctx + (coeff_ctx[i] & 15);
+    int dq = tcq_quant(i);
+    rate_zero[i] = txb_costs->base_cost_uv[base_ctx][dq][0];
+  }
+  rate[0] = get_coeff_cost_def(absLevel[0], coeff_ctx[0], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[1] = get_coeff_cost_def(absLevel[2], coeff_ctx[0], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[2] = get_coeff_cost_def(absLevel[0], coeff_ctx[1], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[3] = get_coeff_cost_def(absLevel[2], coeff_ctx[1], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[4] = get_coeff_cost_def(absLevel[1], coeff_ctx[2], diag_ctx, plane,
+                               txb_costs, 1, t_sign, sign);
+  rate[5] = get_coeff_cost_def(absLevel[3], coeff_ctx[2], diag_ctx, plane,
+                               txb_costs, 1, t_sign, sign);
+  rate[6] = get_coeff_cost_def(absLevel[1], coeff_ctx[3], diag_ctx, plane,
+                               txb_costs, 1, t_sign, sign);
+  rate[7] = get_coeff_cost_def(absLevel[3], coeff_ctx[3], diag_ctx, plane,
+                               txb_costs, 1, t_sign, sign);
+  dist[0] = deltaDist[0];
+  dist[1] = deltaDist[2];
+  dist[2] = deltaDist[0];
+  dist[3] = deltaDist[2];
+  dist[4] = deltaDist[1];
+  dist[5] = deltaDist[3];
+  dist[6] = deltaDist[1];
+  dist[7] = deltaDist[3];
+#if MORESTATES
+  rate[8] = get_coeff_cost_def(absLevel[0], coeff_ctx[4], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[9] = get_coeff_cost_def(absLevel[2], coeff_ctx[4], diag_ctx, plane,
+                               txb_costs, 0, t_sign, sign);
+  rate[10] = get_coeff_cost_def(absLevel[0], coeff_ctx[5], diag_ctx, plane,
+                                txb_costs, 0, t_sign, sign);
+  rate[11] = get_coeff_cost_def(absLevel[2], coeff_ctx[5], diag_ctx, plane,
+                                txb_costs, 0, t_sign, sign);
+  rate[12] = get_coeff_cost_def(absLevel[1], coeff_ctx[6], diag_ctx, plane,
+                                txb_costs, 1, t_sign, sign);
+  rate[13] = get_coeff_cost_def(absLevel[3], coeff_ctx[6], diag_ctx, plane,
+                                txb_costs, 1, t_sign, sign);
+  rate[14] = get_coeff_cost_def(absLevel[1], coeff_ctx[7], diag_ctx, plane,
+                                txb_costs, 1, t_sign, sign);
+  rate[15] = get_coeff_cost_def(absLevel[3], coeff_ctx[7], diag_ctx, plane,
+                                txb_costs, 1, t_sign, sign);
+  dist[8] = deltaDist[0];
+  dist[9] = deltaDist[2];
+  dist[10] = deltaDist[0];
+  dist[11] = deltaDist[2];
+  dist[12] = deltaDist[1];
+  dist[13] = deltaDist[3];
+  dist[14] = deltaDist[1];
+  dist[15] = deltaDist[3];
+#endif
+#if 0
+  static int n = 0;
+  int s_rate_zero[TOTALSTATES];
+  int s_rate[2 * TOTALSTATES];
+  int64_t s_dist[2 * TOTALSTATES];
+  av1_get_rate_dist_def_chroma_avx2(txb_costs, pq, coeff_ctx, diag_ctx, plane, t_sign, sign,
+                                    s_rate_zero, s_rate, s_dist);
+  int ok = 1;
+  for (int i = 0; i < TOTALSTATES; i++) {
+    int chk = s_rate_zero[i] == rate_zero[i];
+    if (!chk) {
+      printf("CHK i %d rate_zero %d %d\n", i, s_rate_zero[i], rate_zero[i]);
+    }
+    ok &= chk;
+  }
+  for (int i = 0; i < 2*TOTALSTATES; i++) {
+    int chk = s_rate[i] == rate[i];
+    if (!chk) {
+      printf("CHK i %d rate %d %d\n", i, s_rate[i], rate[i]);
+    }
+    ok &= chk;
+  }
+  for (int i = 0; i < 2*TOTALSTATES; i++) {
+    int chk = s_dist[i] == dist[i];
+    if (!chk) {
+      printf("CHK i %d dist %ld %ld\n", i, s_dist[i], dist[i]);
+    }
+    ok &= chk;
+  }
+  n++;
+  if (!ok) {
+    printf("plane %d t_sign %d sign %d\n", plane, t_sign, sign);
+    exit(1);
+  }
+#endif
+}
+
+void av1_get_rate_dist_lf_c(const struct LV_MAP_COEFF_COST *txb_costs,
+                            const struct prequant_t *pq,
+                            const uint8_t coeff_ctx[TOTALSTATES + 4],
+                            int diag_ctx, int dc_sign_ctx, int32_t *tmp_sign,
+                            int bwl, TX_CLASS tx_class, int plane, int blk_pos,
+                            int coeff_sign, int32_t rate_zero[TOTALSTATES],
+                            int32_t rate[2 * TOTALSTATES],
+                            int64_t dist[2 * TOTALSTATES]) {
+  const tran_low_t *absLevel = pq->absLevel;
+  const int64_t *deltaDist = pq->deltaDist;
+  uint8_t base_ctx[TOTALSTATES];
+  uint8_t mid_ctx[TOTALSTATES];
+
+  for (int i = 0; i < TOTALSTATES; i++) {
+    base_ctx[i] = (coeff_ctx[i] & 15) + diag_ctx;
+    mid_ctx[i] = coeff_ctx[i] >> 4;
+  }
+
+  // calculate RDcost
+  for (int i = 0; i < TOTALSTATES; i++) {
+    int dq = tcq_quant(i);
+    const int(*base_lf_cost_ptr)[DQ_CTXS][LF_BASE_SYMBOLS * 2] =
+        plane > 0 ? txb_costs->base_lf_cost_uv : txb_costs->base_lf_cost;
+    rate_zero[i] = base_lf_cost_ptr[base_ctx[i]][dq][0];
+    if (0)
+      printf(
+          "save_rate_zero[%d] = base_lf_cost_ptr[%d][%d][%d] = %04x %04x "
+          "coeff_ctx %02X diag_ctx %d\n",
+          i, base_ctx[i], dq, 0, base_lf_cost_ptr[base_ctx[i]][dq][0],
+          rate_zero[i], coeff_ctx[i], diag_ctx);
+  }
+
+  rate[0] = get_coeff_cost(blk_pos, absLevel[0], coeff_sign, base_ctx[0],
+                           mid_ctx[0], dc_sign_ctx, txb_costs, bwl, tx_class,
+                           tmp_sign, plane, 1, 0);
+  rate[1] = get_coeff_cost(blk_pos, absLevel[2], coeff_sign, base_ctx[0],
+                           mid_ctx[0], dc_sign_ctx, txb_costs, bwl, tx_class,
+                           tmp_sign, plane, 1, 0);
+  rate[2] = get_coeff_cost(blk_pos, absLevel[0], coeff_sign, base_ctx[1],
+                           mid_ctx[1], dc_sign_ctx, txb_costs, bwl, tx_class,
+                           tmp_sign, plane, 1, 0);
+  rate[3] = get_coeff_cost(blk_pos, absLevel[2], coeff_sign, base_ctx[1],
+                           mid_ctx[1], dc_sign_ctx, txb_costs, bwl, tx_class,
+                           tmp_sign, plane, 1, 0);
+  rate[4] = get_coeff_cost(blk_pos, absLevel[1], coeff_sign, base_ctx[2],
+                           mid_ctx[2], dc_sign_ctx, txb_costs, bwl, tx_class,
+                           tmp_sign, plane, 1, 1);
+  rate[5] = get_coeff_cost(blk_pos, absLevel[3], coeff_sign, base_ctx[2],
+                           mid_ctx[2], dc_sign_ctx, txb_costs, bwl, tx_class,
+                           tmp_sign, plane, 1, 1);
+  rate[6] = get_coeff_cost(blk_pos, absLevel[1], coeff_sign, base_ctx[3],
+                           mid_ctx[3], dc_sign_ctx, txb_costs, bwl, tx_class,
+                           tmp_sign, plane, 1, 1);
+  rate[7] = get_coeff_cost(blk_pos, absLevel[3], coeff_sign, base_ctx[3],
+                           mid_ctx[3], dc_sign_ctx, txb_costs, bwl, tx_class,
+                           tmp_sign, plane, 1, 1);
+#if MORESTATES
+  rate[8] = get_coeff_cost(blk_pos, absLevel[0], coeff_sign, base_ctx[4],
+                           mid_ctx[4], dc_sign_ctx, txb_costs, bwl, tx_class,
+                           tmp_sign, plane, 1, 0);
+  rate[9] = get_coeff_cost(blk_pos, absLevel[2], coeff_sign, base_ctx[4],
+                           mid_ctx[4], dc_sign_ctx, txb_costs, bwl, tx_class,
+                           tmp_sign, plane, 1, 0);
+  rate[10] = get_coeff_cost(blk_pos, absLevel[0], coeff_sign, base_ctx[5],
+                            mid_ctx[5], dc_sign_ctx, txb_costs, bwl, tx_class,
+                            tmp_sign, plane, 1, 0);
+  rate[11] = get_coeff_cost(blk_pos, absLevel[2], coeff_sign, base_ctx[5],
+                            mid_ctx[5], dc_sign_ctx, txb_costs, bwl, tx_class,
+                            tmp_sign, plane, 1, 0);
+  rate[12] = get_coeff_cost(blk_pos, absLevel[1], coeff_sign, base_ctx[6],
+                            mid_ctx[6], dc_sign_ctx, txb_costs, bwl, tx_class,
+                            tmp_sign, plane, 1, 1);
+  rate[13] = get_coeff_cost(blk_pos, absLevel[3], coeff_sign, base_ctx[6],
+                            mid_ctx[6], dc_sign_ctx, txb_costs, bwl, tx_class,
+                            tmp_sign, plane, 1, 1);
+  rate[14] = get_coeff_cost(blk_pos, absLevel[1], coeff_sign, base_ctx[7],
+                            mid_ctx[7], dc_sign_ctx, txb_costs, bwl, tx_class,
+                            tmp_sign, plane, 1, 1);
+  rate[15] = get_coeff_cost(blk_pos, absLevel[3], coeff_sign, base_ctx[7],
+                            mid_ctx[7], dc_sign_ctx, txb_costs, bwl, tx_class,
+                            tmp_sign, plane, 1, 1);
+#endif
+  dist[0] = deltaDist[0];
+  dist[1] = deltaDist[2];
+  dist[2] = deltaDist[0];
+  dist[3] = deltaDist[2];
+  dist[4] = deltaDist[1];
+  dist[5] = deltaDist[3];
+  dist[6] = deltaDist[1];
+  dist[7] = deltaDist[3];
+#if MORESTATES
+  dist[8] = deltaDist[0];
+  dist[9] = deltaDist[2];
+  dist[10] = deltaDist[0];
+  dist[11] = deltaDist[2];
+  dist[12] = deltaDist[1];
+  dist[13] = deltaDist[3];
+  dist[14] = deltaDist[1];
+  dist[15] = deltaDist[3];
+#endif
+#if 0
+  static int n = 0;
+  int t_sign = tmp_sign[blk_pos];
+  int sign = coeff_sign;
+  int s_rate_zero[TOTALSTATES];
+  int s_rate[2 * TOTALSTATES];
+  int64_t s_dist[2 * TOTALSTATES];
+  av1_get_rate_dist_lf_avx2(txb_costs, pq, coeff_ctx, diag_ctx, dc_sign_ctx,
+                            tmp_sign, bwl, tx_class, plane, blk_pos, coeff_sign,
+                            s_rate_zero, s_rate, s_dist);
+  int ok = 1;
+  for (int i = 0; i < TOTALSTATES; i++) {
+    int chk = s_rate_zero[i] == rate_zero[i];
+    if (!chk) {
+      printf("CHK i %d rate_zero %d %d\n", i, s_rate_zero[i], rate_zero[i]);
+    }
+    ok &= chk;
+  }
+  for (int i = 0; i < 2*TOTALSTATES; i++) {
+    int chk = s_rate[i] == rate[i];
+    if (!chk || n == 0) {
+      printf("CHK i %d rate %d %d\n", i, s_rate[i], rate[i]);
+    }
+    ok &= chk;
+  }
+  for (int i = 0; i < 2*TOTALSTATES; i++) {
+    int chk = s_dist[i] == dist[i];
+    if (!chk) {
+      printf("CHK i %d dist %ld %ld\n", i, s_dist[i], dist[i]);
+    }
+    ok &= chk;
+  }
+  n++;
+  if (!ok) {
+    printf("plane %d t_sign %d sign %d\n", plane, t_sign, sign);
+    exit(1);
+  }
+#endif
+}
+
 void av1_calc_diag_ctx_c(int scan_hi, int scan_lo, int bwl,
                          const uint8_t *prev_levels, const int16_t *scan,
                          uint8_t *ctx) {
@@ -1114,6 +1367,7 @@
   }
 }
 
+// Handle trellis default region for Luma, TX_CLASS_2D blocks.
 void trellis_loop_diagonal(
     int scan_hi, int scan_lo, int plane, TX_SIZE tx_size, TX_TYPE tx_type,
     int32_t *tmp_sign, int sharpness, tcq_levels_t *tcq_lev,
@@ -1128,12 +1382,7 @@
   const int shift = av1_get_tx_scale(tx_size);
   const TX_CLASS tx_class = tx_type_to_class[get_primary_tx_type(tx_type)];
   const int pos0 = scan[scan_hi];
-  const int row0 = pos0 >> bwl;
-  const int col0 = pos0 - (row0 << bwl);
-  const int diag_ctx = (plane != 0 || row0 + col0 < 6) ? 0
-                       : row0 + col0 < 8               ? 5
-                                                       : 10;
-
+  const int diag_ctx = get_nz_map_ctx_from_stats(0, pos0, bwl, TX_CLASS_2D, 0);
   assert(plane == 0);
   assert(tx_class == TX_CLASS_2D);
 
@@ -1164,13 +1413,13 @@
     int64_t dist[2 * TOTALSTATES];
 
     // calculate rate distortion
-    uint8_t coeff_ctx[2 * TOTALSTATES];  // extra alloc for simd (loadu_si64)
+    uint8_t coeff_ctx[TOTALSTATES + 4];  // extra +4 alloc to allow SIMD load.
     for (int i = 0; i < TOTALSTATES; i++) {
       coeff_ctx[i] = tcq_ctx[i].ctx[scan_pos - scan_lo];
     }
 
-    av1_get_rate_dist_def(txb_costs, &pqData, coeff_ctx, diag_ctx, plane,
-                          rate_zero, rate, dist);
+    av1_get_rate_dist_def_luma(txb_costs, &pqData, coeff_ctx, diag_ctx,
+                               rate_zero, rate, dist);
 
     av1_decide_states(prev_decision, dist, rate, rate_zero, &pqData, limits,
                       rdmult, decision);
@@ -1223,6 +1472,7 @@
                          scan_hi, scan_lo, tcq_ctx);
 }
 
+// Handle trellis Low-freq (LF) region for Luma, TX_CLASS_2D blocks.
 void trellis_loop_lf(int first_scan_pos, int scan_hi, int scan_lo, int plane,
                      TX_SIZE tx_size, TX_TYPE tx_type, int32_t *tmp_sign,
                      int sharpness, tcq_levels_t *tcq_lev,
@@ -1237,11 +1487,14 @@
   const int height = get_txb_high(tx_size);
   const int shift = av1_get_tx_scale(tx_size);
   const TX_CLASS tx_class = tx_type_to_class[get_primary_tx_type(tx_type)];
+  assert(plane == 0);
+  assert(tx_class == TX_CLASS_2D);
 
   for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
     tcq_levels_swap(tcq_lev);
 
     int blk_pos = scan[scan_pos];
+
     tcq_node_t *decision = trellis[scan_pos];
     tcq_node_t *prd = trellis[scan_pos + 1];
 
@@ -1254,113 +1507,32 @@
     init_tcq_decision(decision);
     const int coeff_sign = tcoeff[blk_pos] < 0;
     const int limits = 1;  // Always in LF region.
-    int rate_zero[TOTALSTATES];
-    int rate[2 * TOTALSTATES];
-    int64_t dist[2 * TOTALSTATES];
 
-    // calculate rate distortion
-    uint8_t coeff_ctx[TOTALSTATES];
-    uint8_t mid_ctx[TOTALSTATES];
+    // calculate contexts
+    int diag_ctx = get_nz_map_ctx_from_stats_lf(0, blk_pos, bwl, tx_class);
+    uint8_t coeff_ctx[TOTALSTATES + 4];
     for (int i = 0; i < TOTALSTATES; i++) {
       uint8_t *prev_lev = tcq_levels_prev(tcq_lev, i);
       coeff_ctx[i] = get_lower_levels_lf_ctx(prev_lev, blk_pos, bwl, tx_class);
       int br_ctx = plane
                        ? get_br_lf_ctx_chroma(prev_lev, blk_pos, bwl, tx_class)
                        : get_br_lf_ctx(prev_lev, blk_pos, bwl, tx_class);
-      mid_ctx[i] = br_ctx;
+      coeff_ctx[i] -= diag_ctx;
+      coeff_ctx[i] += br_ctx << 4;
     }
 
-    // calculate RDcost
-    for (int i = 0; i < TOTALSTATES; i++) {
-      int dq = tcq_quant(i);
-#if CONFIG_LCCHROMA
-      const int(*base_lf_cost_ptr)[DQ_CTXS][LF_BASE_SYMBOLS * 2] =
-          plane > 0 ? txb_costs->base_lf_cost_uv : txb_costs->base_lf_cost;
-      rate_zero[i] = base_lf_cost_ptr[coeff_ctx[i]][dq][0];
-#else
-      rate_zero[i] = txb_costs->base_lf_cost[coeff_ctx[i]][0];
-#endif
-    }
+    // calculate rate distortion
+    int rate_zero[TOTALSTATES];
+    int rate[2 * TOTALSTATES];
+    int64_t dist[2 * TOTALSTATES];
+    av1_get_rate_dist_lf(txb_costs, &pqData, coeff_ctx, diag_ctx,
+                         txb_ctx->dc_sign_ctx, tmp_sign, bwl, tx_class, plane,
+                         blk_pos, coeff_sign, rate_zero, rate, dist);
 
-    int dc_sign_ctx = txb_ctx->dc_sign_ctx;
-    rate[0] = get_coeff_cost(blk_pos, pqData.absLevel[0], coeff_sign,
-                             coeff_ctx[0], mid_ctx[0], dc_sign_ctx, txb_costs,
-                             bwl, tx_class, tmp_sign, plane, 1, 0);
-    rate[1] = get_coeff_cost(blk_pos, pqData.absLevel[2], coeff_sign,
-                             coeff_ctx[0], mid_ctx[0], dc_sign_ctx, txb_costs,
-                             bwl, tx_class, tmp_sign, plane, 1, 0);
-    rate[2] = get_coeff_cost(blk_pos, pqData.absLevel[0], coeff_sign,
-                             coeff_ctx[1], mid_ctx[1], dc_sign_ctx, txb_costs,
-                             bwl, tx_class, tmp_sign, plane, 1, 0);
-    rate[3] = get_coeff_cost(blk_pos, pqData.absLevel[2], coeff_sign,
-                             coeff_ctx[1], mid_ctx[1], dc_sign_ctx, txb_costs,
-                             bwl, tx_class, tmp_sign, plane, 1, 0);
-    rate[4] = get_coeff_cost(blk_pos, pqData.absLevel[1], coeff_sign,
-                             coeff_ctx[2], mid_ctx[2], dc_sign_ctx, txb_costs,
-                             bwl, tx_class, tmp_sign, plane, 1, 1);
-    rate[5] = get_coeff_cost(blk_pos, pqData.absLevel[3], coeff_sign,
-                             coeff_ctx[2], mid_ctx[2], dc_sign_ctx, txb_costs,
-                             bwl, tx_class, tmp_sign, plane, 1, 1);
-    rate[6] = get_coeff_cost(blk_pos, pqData.absLevel[1], coeff_sign,
-                             coeff_ctx[3], mid_ctx[3], dc_sign_ctx, txb_costs,
-                             bwl, tx_class, tmp_sign, plane, 1, 1);
-    rate[7] = get_coeff_cost(blk_pos, pqData.absLevel[3], coeff_sign,
-                             coeff_ctx[3], mid_ctx[3], dc_sign_ctx, txb_costs,
-                             bwl, tx_class, tmp_sign, plane, 1, 1);
-#if MORESTATES
-    rate[8] = get_coeff_cost(blk_pos, pqData.absLevel[0], coeff_sign,
-                             coeff_ctx[4], mid_ctx[4], dc_sign_ctx, txb_costs,
-                             bwl, tx_class, tmp_sign, plane, 1, 0);
-    rate[9] = get_coeff_cost(blk_pos, pqData.absLevel[2], coeff_sign,
-                             coeff_ctx[4], mid_ctx[4], dc_sign_ctx, txb_costs,
-                             bwl, tx_class, tmp_sign, plane, 1, 0);
-    rate[10] = get_coeff_cost(blk_pos, pqData.absLevel[0], coeff_sign,
-                              coeff_ctx[5], mid_ctx[5], dc_sign_ctx, txb_costs,
-                              bwl, tx_class, tmp_sign, plane, 1, 0);
-    rate[11] = get_coeff_cost(blk_pos, pqData.absLevel[2], coeff_sign,
-                              coeff_ctx[5], mid_ctx[5], dc_sign_ctx, txb_costs,
-                              bwl, tx_class, tmp_sign, plane, 1, 0);
-    rate[12] = get_coeff_cost(blk_pos, pqData.absLevel[1], coeff_sign,
-                              coeff_ctx[6], mid_ctx[6], dc_sign_ctx, txb_costs,
-                              bwl, tx_class, tmp_sign, plane, 1, 1);
-    rate[13] = get_coeff_cost(blk_pos, pqData.absLevel[3], coeff_sign,
-                              coeff_ctx[6], mid_ctx[6], dc_sign_ctx, txb_costs,
-                              bwl, tx_class, tmp_sign, plane, 1, 1);
-    rate[14] = get_coeff_cost(blk_pos, pqData.absLevel[1], coeff_sign,
-                              coeff_ctx[7], mid_ctx[7], dc_sign_ctx, txb_costs,
-                              bwl, tx_class, tmp_sign, plane, 1, 1);
-    rate[15] = get_coeff_cost(blk_pos, pqData.absLevel[3], coeff_sign,
-                              coeff_ctx[7], mid_ctx[7], dc_sign_ctx, txb_costs,
-                              bwl, tx_class, tmp_sign, plane, 1, 1);
-#endif
-    dist[0] = pqData.deltaDist[0];
-    dist[1] = pqData.deltaDist[2];
-    dist[2] = pqData.deltaDist[0];
-    dist[3] = pqData.deltaDist[2];
-    dist[4] = pqData.deltaDist[1];
-    dist[5] = pqData.deltaDist[3];
-    dist[6] = pqData.deltaDist[1];
-    dist[7] = pqData.deltaDist[3];
-#if MORESTATES
-    dist[8] = pqData.deltaDist[0];
-    dist[9] = pqData.deltaDist[2];
-    dist[10] = pqData.deltaDist[0];
-    dist[11] = pqData.deltaDist[2];
-    dist[12] = pqData.deltaDist[1];
-    dist[13] = pqData.deltaDist[3];
-    dist[14] = pqData.deltaDist[1];
-    dist[15] = pqData.deltaDist[3];
-#endif
-
-    // todo: Q0 can skip the sig_flag or skip some another flag. This is not
-    // included in the calculation of RDcost now.
     av1_decide_states(prd, dist, rate, rate_zero, &pqData, limits, rdmult,
                       decision);
 
-    // assume current state is 0, current coeff is new eob.
-    // input: scan_pos,pqData[0],pqData[2], decison[0] and decision[2]
-    //  update eob if better use current position as eob
-
+    // Use the current position as the new EOB if that is the better choice.
     if (sharpness == 0) {
       int new_eob_rate = block_eob_rate[scan_pos];
       int new_eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
@@ -1435,7 +1607,12 @@
     }
 
     int blk_pos = scan[scan_pos];
+    int row = blk_pos >> bwl;
+    int col = blk_pos - (row << bwl);
+    int limits = get_lf_limits(row, col, tx_class, plane);
+
     tcq_node_t *decision = trellis[scan_pos];
+    tcq_node_t *prd = trellis[scan_pos + 1];
 
     prequant_t pqData;
     int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
@@ -1446,23 +1623,30 @@
     init_tcq_decision(decision);
     const int coeff_sign = tcoeff[blk_pos] < 0;
 
-    const int row = blk_pos >> bwl;
-    const int col = blk_pos - (row << bwl);
-    int limits = get_lf_limits(row, col, tx_class, plane);
-
-    // calculate rate distortion
-    uint8_t coeff_ctx[TOTALSTATES];
+    // calculate contexts
+    int diag_ctx =
+        (limits && plane == 0)
+            ? get_nz_map_ctx_from_stats_lf(0, blk_pos, bwl, tx_class)
+        : plane == 0 ? get_nz_map_ctx_from_stats(0, blk_pos, bwl, tx_class, 0)
+        : limits
+            ? get_nz_map_ctx_from_stats_lf_chroma(0, tx_class, plane)
+            : get_nz_map_ctx_from_stats_chroma(0, blk_pos, tx_class, plane);
+    uint8_t coeff_ctx[TOTALSTATES + 4];
     if (limits) {
       for (int i = 0; i < TOTALSTATES; i++) {
-        coeff_ctx[i] = plane
+        int base_ctx = plane
                            ? get_lower_levels_lf_ctx_chroma(
                                  prev_levels[i], blk_pos, bwl, tx_class, plane)
                            : get_lower_levels_lf_ctx(prev_levels[i], blk_pos,
                                                      bwl, tx_class);
+        int br_ctx =
+            plane ? get_br_lf_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class)
+                  : get_br_lf_ctx(prev_levels[i], blk_pos, bwl, tx_class);
+        coeff_ctx[i] = base_ctx - diag_ctx + (br_ctx << 4);
       }
     } else {
       for (int i = 0; i < TOTALSTATES; i++) {
-        coeff_ctx[i] =
+        int base_ctx =
             plane ? get_lower_levels_ctx_chroma(prev_levels[i], blk_pos, bwl,
                                                 tx_class, plane)
                   : get_lower_levels_ctx(prev_levels[i], blk_pos, bwl, tx_class
@@ -1471,212 +1655,63 @@
                                          plane
 #endif
                     );
-      }
-    }
-    uint8_t mid_ctx[TOTALSTATES];
-    if (limits) {
-      for (int i = 0; i < TOTALSTATES; i++) {
-        int br_ctx =
-            plane ? get_br_lf_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class)
-                  : get_br_lf_ctx(prev_levels[i], blk_pos, bwl, tx_class);
-        mid_ctx[i] = br_ctx;
-      }
-    } else {
-      for (int i = 0; i < TOTALSTATES; i++) {
         int br_ctx =
             plane ? get_br_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class)
                   : get_br_ctx(prev_levels[i], blk_pos, bwl, tx_class);
-        mid_ctx[i] = br_ctx;
+        coeff_ctx[i] = base_ctx - diag_ctx + (br_ctx << 4);
       }
     }
 
-    // calculate RDcost
-    int rate_zero[TOTALSTATES];
-    if (limits) {
-      for (int i = 0; i < TOTALSTATES; i++) {
-        int dq = tcq_quant(i);
-#if CONFIG_LCCHROMA
-        const int(*base_lf_cost_ptr)[DQ_CTXS][LF_BASE_SYMBOLS * 2] =
-            plane > 0 ? txb_costs->base_lf_cost_uv : txb_costs->base_lf_cost;
-        rate_zero[i] = base_lf_cost_ptr[coeff_ctx[i]][dq][0];
-#else
-        rate_zero[i] = txb_costs->base_lf_cost[coeff_ctx[i]][0];
-#endif
-      }
-    } else {
-      for (int i = 0; i < TOTALSTATES; i++) {
-        int dq = tcq_quant(i);
-#if CONFIG_LCCHROMA
-        const int(*base_cost_ptr)[DQ_CTXS][8] =
-            plane > 0 ? txb_costs->base_cost_uv : txb_costs->base_cost;
-        rate_zero[i] = base_cost_ptr[coeff_ctx[i]][dq][0];
-#else
-        rate_zero[i] = txb_costs->base_cost[coeff_ctx[i]][0];
-#endif
-      }
-    }
-
-    tcq_node_t *prd = trellis[scan_pos + 1];
-    int rate[2 * TOTALSTATES];
-    int dc_sign_ctx = txb_ctx->dc_sign_ctx;
-    if (limits) {
-      rate[0] = get_coeff_cost(blk_pos, pqData.absLevel[0], coeff_sign,
-                               coeff_ctx[0], mid_ctx[0], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 1, 0);
-      rate[1] = get_coeff_cost(blk_pos, pqData.absLevel[2], coeff_sign,
-                               coeff_ctx[0], mid_ctx[0], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 1, 0);
-      rate[2] = get_coeff_cost(blk_pos, pqData.absLevel[0], coeff_sign,
-                               coeff_ctx[1], mid_ctx[1], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 1, 0);
-      rate[3] = get_coeff_cost(blk_pos, pqData.absLevel[2], coeff_sign,
-                               coeff_ctx[1], mid_ctx[1], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 1, 0);
-      rate[4] = get_coeff_cost(blk_pos, pqData.absLevel[1], coeff_sign,
-                               coeff_ctx[2], mid_ctx[2], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 1, 1);
-      rate[5] = get_coeff_cost(blk_pos, pqData.absLevel[3], coeff_sign,
-                               coeff_ctx[2], mid_ctx[2], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 1, 1);
-      rate[6] = get_coeff_cost(blk_pos, pqData.absLevel[1], coeff_sign,
-                               coeff_ctx[3], mid_ctx[3], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 1, 1);
-      rate[7] = get_coeff_cost(blk_pos, pqData.absLevel[3], coeff_sign,
-                               coeff_ctx[3], mid_ctx[3], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 1, 1);
-#if MORESTATES
-      rate[8] = get_coeff_cost(blk_pos, pqData.absLevel[0], coeff_sign,
-                               coeff_ctx[4], mid_ctx[4], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 1, 0);
-      rate[9] = get_coeff_cost(blk_pos, pqData.absLevel[2], coeff_sign,
-                               coeff_ctx[4], mid_ctx[4], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 1, 0);
-      rate[10] = get_coeff_cost(
-          blk_pos, pqData.absLevel[0], coeff_sign, coeff_ctx[5], mid_ctx[5],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 1, 0);
-      rate[11] = get_coeff_cost(
-          blk_pos, pqData.absLevel[2], coeff_sign, coeff_ctx[5], mid_ctx[5],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 1, 0);
-      rate[12] = get_coeff_cost(
-          blk_pos, pqData.absLevel[1], coeff_sign, coeff_ctx[6], mid_ctx[6],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 1, 1);
-      rate[13] = get_coeff_cost(
-          blk_pos, pqData.absLevel[3], coeff_sign, coeff_ctx[6], mid_ctx[6],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 1, 1);
-      rate[14] = get_coeff_cost(
-          blk_pos, pqData.absLevel[1], coeff_sign, coeff_ctx[7], mid_ctx[7],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 1, 1);
-      rate[15] = get_coeff_cost(
-          blk_pos, pqData.absLevel[3], coeff_sign, coeff_ctx[7], mid_ctx[7],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 1, 1);
-#endif
-    } else {
-      rate[0] = get_coeff_cost(blk_pos, pqData.absLevel[0], coeff_sign,
-                               coeff_ctx[0], mid_ctx[0], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 0, 0);
-      rate[1] = get_coeff_cost(blk_pos, pqData.absLevel[2], coeff_sign,
-                               coeff_ctx[0], mid_ctx[0], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 0, 0);
-      rate[2] = get_coeff_cost(blk_pos, pqData.absLevel[0], coeff_sign,
-                               coeff_ctx[1], mid_ctx[1], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 0, 0);
-      rate[3] = get_coeff_cost(blk_pos, pqData.absLevel[2], coeff_sign,
-                               coeff_ctx[1], mid_ctx[1], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 0, 0);
-      rate[4] = get_coeff_cost(blk_pos, pqData.absLevel[1], coeff_sign,
-                               coeff_ctx[2], mid_ctx[2], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 0, 1);
-      rate[5] = get_coeff_cost(blk_pos, pqData.absLevel[3], coeff_sign,
-                               coeff_ctx[2], mid_ctx[2], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 0, 1);
-      rate[6] = get_coeff_cost(blk_pos, pqData.absLevel[1], coeff_sign,
-                               coeff_ctx[3], mid_ctx[3], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 0, 1);
-      rate[7] = get_coeff_cost(blk_pos, pqData.absLevel[3], coeff_sign,
-                               coeff_ctx[3], mid_ctx[3], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 0, 1);
-#if MORESTATES
-      rate[8] = get_coeff_cost(blk_pos, pqData.absLevel[0], coeff_sign,
-                               coeff_ctx[4], mid_ctx[4], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 0, 0);
-      rate[9] = get_coeff_cost(blk_pos, pqData.absLevel[2], coeff_sign,
-                               coeff_ctx[4], mid_ctx[4], dc_sign_ctx, txb_costs,
-                               bwl, tx_class, tmp_sign, plane, 0, 0);
-      rate[10] = get_coeff_cost(
-          blk_pos, pqData.absLevel[0], coeff_sign, coeff_ctx[5], mid_ctx[5],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 0, 0);
-      rate[11] = get_coeff_cost(
-          blk_pos, pqData.absLevel[2], coeff_sign, coeff_ctx[5], mid_ctx[5],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 0, 0);
-      rate[12] = get_coeff_cost(
-          blk_pos, pqData.absLevel[1], coeff_sign, coeff_ctx[6], mid_ctx[6],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 0, 1);
-      rate[13] = get_coeff_cost(
-          blk_pos, pqData.absLevel[3], coeff_sign, coeff_ctx[6], mid_ctx[6],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 0, 1);
-      rate[14] = get_coeff_cost(
-          blk_pos, pqData.absLevel[1], coeff_sign, coeff_ctx[7], mid_ctx[7],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 0, 1);
-      rate[15] = get_coeff_cost(
-          blk_pos, pqData.absLevel[3], coeff_sign, coeff_ctx[7], mid_ctx[7],
-          dc_sign_ctx, txb_costs, bwl, tx_class, tmp_sign, plane, 0, 1);
-#endif
-    }
+    // calculate rate distortion
+    int32_t rate_zero[TOTALSTATES];
+    int32_t rate[2 * TOTALSTATES];
     int64_t dist[2 * TOTALSTATES];
-    {
-      dist[0] = pqData.deltaDist[0];
-      dist[1] = pqData.deltaDist[2];
-      dist[2] = pqData.deltaDist[0];
-      dist[3] = pqData.deltaDist[2];
-      dist[4] = pqData.deltaDist[1];
-      dist[5] = pqData.deltaDist[3];
-      dist[6] = pqData.deltaDist[1];
-      dist[7] = pqData.deltaDist[3];
-#if MORESTATES
-      dist[8] = pqData.deltaDist[0];
-      dist[9] = pqData.deltaDist[2];
-      dist[10] = pqData.deltaDist[0];
-      dist[11] = pqData.deltaDist[2];
-      dist[12] = pqData.deltaDist[1];
-      dist[13] = pqData.deltaDist[3];
-      dist[14] = pqData.deltaDist[1];
-      dist[15] = pqData.deltaDist[3];
-#endif
-      av1_decide_states(prd, dist, rate, rate_zero, &pqData, limits, rdmult,
-                        decision);
+    if (limits) {
+      av1_get_rate_dist_lf(txb_costs, &pqData, coeff_ctx, diag_ctx,
+                           txb_ctx->dc_sign_ctx, tmp_sign, bwl, tx_class, plane,
+                           blk_pos, coeff_sign, rate_zero, rate, dist);
+    } else if (plane == 0) {
+      av1_get_rate_dist_def_luma(txb_costs, &pqData, coeff_ctx, diag_ctx,
+                                 rate_zero, rate, dist);
+    } else {
+      av1_get_rate_dist_def_chroma(txb_costs, &pqData, coeff_ctx, diag_ctx,
+                                   plane, tmp_sign[blk_pos], coeff_sign,
+                                   rate_zero, rate, dist);
+    }
 
-      if (sharpness == 0) {
-        int new_eob_rate = block_eob_rate[scan_pos];
-        int new_eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
-        int rate_Q0_a =
-            get_coeff_cost_eob(blk_pos, pqData.absLevel[0],
-                               (qcoeff[blk_pos] < 0), new_eob_ctx,
-                               txb_ctx->dc_sign_ctx, txb_costs, bwl, tx_class
+    av1_decide_states(prd, dist, rate, rate_zero, &pqData, limits, rdmult,
+                      decision);
+
+    if (sharpness == 0) {
+      int new_eob_rate = block_eob_rate[scan_pos];
+      int new_eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
+      int rate_Q0_a =
+          get_coeff_cost_eob(blk_pos, pqData.absLevel[0], (qcoeff[blk_pos] < 0),
+                             new_eob_ctx, txb_ctx->dc_sign_ctx, txb_costs, bwl,
+                             tx_class
 #if CONFIG_CONTEXT_DERIVATION
-                               ,
-                               tmp_sign
+                             ,
+                             tmp_sign
 #endif  // CONFIG_CONTEXT_DERIVATION
-                               ,
-                               plane) +
-            new_eob_rate;
-        int rate_Q0_b =
-            get_coeff_cost_eob(blk_pos, pqData.absLevel[2],
-                               (qcoeff[blk_pos] < 0), new_eob_ctx,
-                               txb_ctx->dc_sign_ctx, txb_costs, bwl, tx_class
+                             ,
+                             plane) +
+          new_eob_rate;
+      int rate_Q0_b =
+          get_coeff_cost_eob(blk_pos, pqData.absLevel[2], (qcoeff[blk_pos] < 0),
+                             new_eob_ctx, txb_ctx->dc_sign_ctx, txb_costs, bwl,
+                             tx_class
 #if CONFIG_CONTEXT_DERIVATION
-                               ,
-                               tmp_sign
+                             ,
+                             tmp_sign
 #endif  // CONFIG_CONTEXT_DERIVATION
-                               ,
-                               plane) +
-            new_eob_rate;
-        const int state0 = next_st[0][0];
-        const int state1 = next_st[0][1];
-        decide(0, pqData.deltaDist[0], pqData.deltaDist[2], rdmult, rate_Q0_a,
-               rate_Q0_b, INT32_MAX >> 1, pqData.absLevel[0],
-               pqData.absLevel[2], limits, 0, -1, &decision[state0],
-               &decision[state1]);
-      }
+                             ,
+                             plane) +
+          new_eob_rate;
+      const int state0 = next_st[0][0];
+      const int state1 = next_st[0][1];
+      decide(0, pqData.deltaDist[0], pqData.deltaDist[2], rdmult, rate_Q0_a,
+             rate_Q0_b, INT32_MAX >> 1, pqData.absLevel[0], pqData.absLevel[2],
+             limits, 0, -1, &decision[state0], &decision[state1]);
     }
 
     // copy corresponding context from previous level buffer
diff --git a/av1/encoder/x86/trellis_quant_avx2.c b/av1/encoder/x86/trellis_quant_avx2.c
index 6ef844d..b11348e 100644
--- a/av1/encoder/x86/trellis_quant_avx2.c
+++ b/av1/encoder/x86/trellis_quant_avx2.c
@@ -261,7 +261,7 @@
   return high_range;
 }
 
-static int get_golomb_cost(int abs_qc) {
+static INLINE int get_golomb_cost(int abs_qc) {
 #if NEWHR
   if (abs_qc >= NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
     const int r = 1 + get_high_range(abs_qc, 0);
@@ -278,30 +278,39 @@
   return 0;
 }
 
-static int get_br_cost(tran_low_t level, const int *coeff_lps) {
+static INLINE int get_br_cost(tran_low_t level, const int *coeff_lps) {
   const int base_range = AOMMIN(level - 1 - NUM_BASE_LEVELS, COEFF_BASE_RANGE);
   return coeff_lps[base_range] + get_golomb_cost(level);
 }
 
-static int get_coeff_mid_cost_def(tran_low_t abs_qc, int coeff_ctx,
-                                  const LV_MAP_COEFF_COST *txb_costs) {
+static INLINE int get_mid_cost_def(tran_low_t abs_qc, int coeff_ctx,
+                                   const LV_MAP_COEFF_COST *txb_costs,
+                                   int plane, int t_sign, int sign) {
   int cost = 0;
+  if (plane == AOM_PLANE_V) {
+    cost += txb_costs->v_ac_sign_cost[t_sign][sign] - av1_cost_literal(1);
+  }
   if (abs_qc > NUM_BASE_LEVELS) {
     int mid_ctx = coeff_ctx >> 4;
-    cost += get_br_cost(abs_qc, txb_costs->lps_cost[mid_ctx]);
+    if (plane == 0) {
+      cost += get_br_cost(abs_qc, txb_costs->lps_cost[mid_ctx]);
+    } else {
+      cost += get_br_cost(abs_qc, txb_costs->lps_cost_uv[mid_ctx]);
+    }
   }
   return cost;
 }
 
-void av1_get_rate_dist_def_avx2(const struct LV_MAP_COEFF_COST *txb_costs,
-                                const struct prequant_t *pq,
-                                const uint8_t coeff_ctx[2 * TOTALSTATES],
-                                int diag_ctx, int plane,
-                                int32_t rate_zero[TOTALSTATES],
-                                int32_t rate[2 * TOTALSTATES],
-                                int64_t dist[2 * TOTALSTATES]) {
-  assert(plane == 0);
-  (void)plane;
+void av1_get_rate_dist_def_luma_avx2(const struct LV_MAP_COEFF_COST *txb_costs,
+                                     const struct prequant_t *pq,
+                                     const uint8_t coeff_ctx[TOTALSTATES + 4],
+                                     int diag_ctx,
+                                     int32_t rate_zero[TOTALSTATES],
+                                     int32_t rate[2 * TOTALSTATES],
+                                     int64_t dist[2 * TOTALSTATES]) {
+  const int32_t(*cost_low)[4][SIG_COEF_CONTEXTS] = txb_costs->base_cost_low;
+  const uint16_t(*cost_low_tbl)[SIG_COEF_CONTEXTS][DQ_CTXS][2] =
+      txb_costs->base_cost_low_tbl;
   const tran_low_t *absLevel = pq->absLevel;
   const int64_t *deltaDist = pq->deltaDist;
 
@@ -319,9 +328,10 @@
   // Calc zero coeff costs.
   __m256i zero = _mm256_setzero_si256();
   __m256i cost_zero_dq0 =
-      _mm256_lddqu_si256((__m256i *)&txb_costs->base_cost_low[0][0][diag_ctx]);
+      _mm256_lddqu_si256((__m256i *)&cost_low[0][0][diag_ctx]);
   __m256i cost_zero_dq1 =
-      _mm256_lddqu_si256((__m256i *)&txb_costs->base_cost_low[1][0][diag_ctx]);
+      _mm256_lddqu_si256((__m256i *)&cost_low[1][0][diag_ctx]);
+
   __m256i ctx = _mm256_castsi128_si256(_mm_loadu_si64(coeff_ctx));
   ctx = _mm256_unpacklo_epi8(ctx, zero);
   ctx = _mm256_shuffle_epi32(ctx, 0xD8);
@@ -346,14 +356,10 @@
     int ctx1 = diag_ctx + (coeff_ctx[j + 1] & 15);
     int ctx2 = diag_ctx + (coeff_ctx[j + 2] & 15);
     int ctx3 = diag_ctx + (coeff_ctx[j + 3] & 15);
-    __m128i rate_01 =
-        _mm_loadu_si64(&txb_costs->base_cost_low_tbl[idx][ctx0][0]);
-    __m128i rate_23 =
-        _mm_loadu_si64(&txb_costs->base_cost_low_tbl[idx][ctx1][0]);
-    __m128i rate_45 =
-        _mm_loadu_si64(&txb_costs->base_cost_low_tbl[idx][ctx2][1]);
-    __m128i rate_67 =
-        _mm_loadu_si64(&txb_costs->base_cost_low_tbl[idx][ctx3][1]);
+    __m128i rate_01 = _mm_loadu_si64(&cost_low_tbl[idx][ctx0][0]);
+    __m128i rate_23 = _mm_loadu_si64(&cost_low_tbl[idx][ctx1][0]);
+    __m128i rate_45 = _mm_loadu_si64(&cost_low_tbl[idx][ctx2][1]);
+    __m128i rate_67 = _mm_loadu_si64(&cost_low_tbl[idx][ctx3][1]);
     __m128i rate_0123 = _mm_unpacklo_epi32(rate_01, rate_23);
     __m128i rate_4567 = _mm_unpacklo_epi32(rate_45, rate_67);
     __m128i c_zero = _mm_setzero_si128();
@@ -365,24 +371,201 @@
 
   // Calc coeff mid and high range cost.
   if (idx > 0) {
-    rate[0] += get_coeff_mid_cost_def(absLevel[0], coeff_ctx[0], txb_costs);
-    rate[1] += get_coeff_mid_cost_def(absLevel[2], coeff_ctx[0], txb_costs);
-    rate[2] += get_coeff_mid_cost_def(absLevel[0], coeff_ctx[1], txb_costs);
-    rate[3] += get_coeff_mid_cost_def(absLevel[2], coeff_ctx[1], txb_costs);
-    rate[4] += get_coeff_mid_cost_def(absLevel[1], coeff_ctx[2], txb_costs);
-    rate[5] += get_coeff_mid_cost_def(absLevel[3], coeff_ctx[2], txb_costs);
-    rate[6] += get_coeff_mid_cost_def(absLevel[1], coeff_ctx[3], txb_costs);
-    rate[7] += get_coeff_mid_cost_def(absLevel[3], coeff_ctx[3], txb_costs);
-#if MORESTATES
-    rate[8] += get_coeff_mid_cost_def(absLevel[0], coeff_ctx[4], txb_costs);
-    rate[9] += get_coeff_mid_cost_def(absLevel[2], coeff_ctx[4], txb_costs);
-    rate[10] += get_coeff_mid_cost_def(absLevel[0], coeff_ctx[5], txb_costs);
-    rate[11] += get_coeff_mid_cost_def(absLevel[2], coeff_ctx[5], txb_costs);
-    rate[12] += get_coeff_mid_cost_def(absLevel[1], coeff_ctx[6], txb_costs);
-    rate[13] += get_coeff_mid_cost_def(absLevel[3], coeff_ctx[6], txb_costs);
-    rate[14] += get_coeff_mid_cost_def(absLevel[1], coeff_ctx[7], txb_costs);
-    rate[15] += get_coeff_mid_cost_def(absLevel[3], coeff_ctx[7], txb_costs);
+    for (int i = 0; i < TOTALSTATES; i++) {
+      int a0 = i & 2 ? 1 : 0;
+      int a1 = a0 + 2;
+      int mid_cost0 =
+          get_mid_cost_def(absLevel[a0], coeff_ctx[i], txb_costs, 0, 0, 0);
+      int mid_cost1 =
+          get_mid_cost_def(absLevel[a1], coeff_ctx[i], txb_costs, 0, 0, 0);
+      rate[2 * i] += mid_cost0;
+      rate[2 * i + 1] += mid_cost1;
+    }
+  }
+}
+
+static AOM_FORCE_INLINE int get_golomb_cost_lf(int abs_qc) {
+#if NEWHR
+  if (abs_qc >= LF_NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
+    const int r = 1 + get_high_range(abs_qc, 1);
+    const int length = get_msb(r) + 1;
+    return av1_cost_literal(2 * length - 1);
+  }
+#else
+  if (abs_qc >= 1 + LF_NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
+    const int r = abs_qc - COEFF_BASE_RANGE - LF_NUM_BASE_LEVELS;
+    const int length = get_msb(r) + 1;
+    return av1_cost_literal(2 * length - 1);
+  }
 #endif
+  return 0;
+}
+
+static AOM_FORCE_INLINE int get_br_lf_cost(tran_low_t level,
+                                           const int *coeff_lps) {
+  const int base_range =
+      AOMMIN(level - 1 - LF_NUM_BASE_LEVELS, COEFF_BASE_RANGE);
+  return coeff_lps[base_range] + get_golomb_cost_lf(level);
+}
+
+static int get_mid_cost_lf_dc(int ci, tran_low_t abs_qc, int sign,
+                              int coeff_ctx, int dc_sign_ctx,
+                              const LV_MAP_COEFF_COST *txb_costs,
+                              const int32_t *tmp_sign, int plane) {
+  int cost = 0;
+  int mid_ctx = coeff_ctx >> 4;
+  const int dc_ph_group = 0;    // PH disabled
+  cost -= av1_cost_literal(1);  // Remove previously added sign cost.
+  if (plane == AOM_PLANE_V)
+    cost += txb_costs->v_dc_sign_cost[tmp_sign[ci]][dc_sign_ctx][sign];
+  else
+    cost += txb_costs->dc_sign_cost[dc_ph_group][dc_sign_ctx][sign];
+  if (plane > 0) {
+    if (abs_qc > LF_NUM_BASE_LEVELS) {
+      cost += get_br_lf_cost(abs_qc, txb_costs->lps_lf_cost_uv[mid_ctx]);
+    }
+  } else {
+    if (abs_qc > LF_NUM_BASE_LEVELS) {
+      cost += get_br_lf_cost(abs_qc, txb_costs->lps_lf_cost[mid_ctx]);
+    }
+  }
+  return cost;
+}
+
+static int get_mid_cost_lf(tran_low_t abs_qc, int coeff_ctx,
+                           const LV_MAP_COEFF_COST *txb_costs, int plane) {
+  int cost = 0;
+  int mid_ctx = coeff_ctx >> 4;
+#if 1
+  assert(plane == 0);
+  (void)plane;
+  if (abs_qc > LF_NUM_BASE_LEVELS) {
+    cost += get_br_lf_cost(abs_qc, txb_costs->lps_lf_cost[mid_ctx]);
+  }
+#else
+  if (plane > 0) {
+    if (abs_qc > LF_NUM_BASE_LEVELS) {
+      cost += get_br_lf_cost(abs_qc, txb_costs->lps_lf_cost_uv[mid_ctx]);
+    }
+  } else {
+    if (abs_qc > LF_NUM_BASE_LEVELS) {
+      cost += get_br_lf_cost(abs_qc, txb_costs->lps_lf_cost[mid_ctx]);
+    }
+  }
+#endif
+  return cost;
+}
+
+void av1_get_rate_dist_lf_avx2(
+    const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
+    const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int dc_sign_ctx,
+    int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int plane, int blk_pos,
+    int coeff_sign, int32_t rate_zero[TOTALSTATES],
+    int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]) {
+  // A -1 index has its MSB set, so _mm256_shuffle_epi8 zero-fills that byte.
+  static const int8_t kShuf[2][32] = {
+    { 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15,
+      0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 },
+    { 0, 8,  -1, -1, 1, 9,  -1, -1, 2, 10, -1, -1, 3, 11, -1, -1,
+      4, 12, -1, -1, 5, 13, -1, -1, 6, 14, -1, -1, 7, 15, -1, -1 }
+  };
+  const uint16_t(*cost_low)[LF_BASE_SYMBOLS][LF_SIG_COEF_CONTEXTS] =
+      plane ? txb_costs->base_lf_cost_uv_low : txb_costs->base_lf_cost_low;
+  const uint16_t(*cost_low_tbl)[LF_SIG_COEF_CONTEXTS][DQ_CTXS][2] =
+      plane ? txb_costs->base_lf_cost_uv_low_tbl
+            : txb_costs->base_lf_cost_low_tbl;
+  const tran_low_t *absLevel = pq->absLevel;
+  const int64_t *deltaDist = pq->deltaDist;
+
+  // Copy distortion stats.
+  __m256i delta_dist = _mm256_lddqu_si256((__m256i *)deltaDist);
+  __m256i dist02 = _mm256_permute4x64_epi64(delta_dist, 0x88);
+  __m256i dist13 = _mm256_permute4x64_epi64(delta_dist, 0xDD);
+  _mm256_storeu_si256((__m256i *)&dist[0], dist02);
+  _mm256_storeu_si256((__m256i *)&dist[4], dist13);
+#if MORESTATES
+  _mm256_storeu_si256((__m256i *)&dist[8], dist02);
+  _mm256_storeu_si256((__m256i *)&dist[12], dist13);
+#endif
+
+  // Calc zero coeff costs.
+  __m256i cost_zero_dq0 =
+      _mm256_lddqu_si256((__m256i *)&cost_low[0][0][diag_ctx]);
+  __m256i cost_zero_dq1 =
+      _mm256_lddqu_si256((__m256i *)&cost_low[1][0][diag_ctx]);
+  __m256i shuf = _mm256_lddqu_si256((__m256i *)kShuf[0]);
+  cost_zero_dq0 = _mm256_shuffle_epi8(cost_zero_dq0, shuf);
+  cost_zero_dq1 = _mm256_shuffle_epi8(cost_zero_dq1, shuf);
+  __m256i cost_dq0 = _mm256_permute4x64_epi64(cost_zero_dq0, 0xD8);
+  __m256i cost_dq1 = _mm256_permute4x64_epi64(cost_zero_dq1, 0xD8);
+  __m256i ctx = _mm256_castsi128_si256(_mm_loadu_si64(coeff_ctx));
+  __m256i fifteen = _mm256_set1_epi8(15);
+  __m256i base_ctx = _mm256_and_si256(ctx, fifteen);
+  base_ctx = _mm256_permute4x64_epi64(base_ctx, 0);
+  __m256i ratez_dq0 = _mm256_shuffle_epi8(cost_dq0, base_ctx);
+  __m256i ratez_dq1 = _mm256_shuffle_epi8(cost_dq1, base_ctx);
+  __m256i ratez = _mm256_blend_epi16(ratez_dq0, ratez_dq1, 0xAA);
+  ratez = _mm256_permute4x64_epi64(ratez, 0x88);
+  __m256i shuf1 = _mm256_lddqu_si256((__m256i *)kShuf[1]);
+  ratez = _mm256_shuffle_epi8(ratez, shuf1);
+#if MORESTATES
+  _mm256_storeu_si256((__m256i *)&rate_zero[0], ratez);
+#else
+  _mm_storeu_si128((__m128i *)&rate_zero[0], _mm256_castsi256_si128(ratez));
+#endif
+
+  // Calc coeff_base rate.
+  int idx = AOMMIN(pq->qIdx - 1, 8);
+  for (int i = 0; i < TOTALSTATES / 4; i++) {
+    int j = 4 * i;
+    int ctx0 = diag_ctx + (coeff_ctx[j + 0] & 15);
+    int ctx1 = diag_ctx + (coeff_ctx[j + 1] & 15);
+    int ctx2 = diag_ctx + (coeff_ctx[j + 2] & 15);
+    int ctx3 = diag_ctx + (coeff_ctx[j + 3] & 15);
+    __m128i rate_01 = _mm_loadu_si64(&cost_low_tbl[idx][ctx0][0]);
+    __m128i rate_23 = _mm_loadu_si64(&cost_low_tbl[idx][ctx1][0]);
+    __m128i rate_45 = _mm_loadu_si64(&cost_low_tbl[idx][ctx2][1]);
+    __m128i rate_67 = _mm_loadu_si64(&cost_low_tbl[idx][ctx3][1]);
+    __m128i rate_0123 = _mm_unpacklo_epi32(rate_01, rate_23);
+    __m128i rate_4567 = _mm_unpacklo_epi32(rate_45, rate_67);
+    __m128i c_zero = _mm_setzero_si128();
+    rate_0123 = _mm_unpacklo_epi16(rate_0123, c_zero);
+    rate_4567 = _mm_unpacklo_epi16(rate_4567, c_zero);
+    _mm_storeu_si128((__m128i *)&rate[8 * i], rate_0123);
+    _mm_storeu_si128((__m128i *)&rate[8 * i + 4], rate_4567);
+  }
+
+  const int row = blk_pos >> bwl;
+  const int col = blk_pos - (row << bwl);
+  const bool dc_2dtx = (blk_pos == 0);
+  const bool dc_hor = (col == 0) && tx_class == TX_CLASS_HORIZ;
+  const bool dc_ver = (row == 0) && tx_class == TX_CLASS_VERT;
+  const bool is_dc_coeff = dc_2dtx || dc_hor || dc_ver;
+  if (is_dc_coeff) {
+    for (int i = 0; i < TOTALSTATES; i++) {
+      int a0 = i & 2 ? 1 : 0;
+      int a1 = a0 + 2;
+      int mid_cost0 =
+          get_mid_cost_lf_dc(blk_pos, absLevel[a0], coeff_sign, coeff_ctx[i],
+                             dc_sign_ctx, txb_costs, tmp_sign, plane);
+      int mid_cost1 =
+          get_mid_cost_lf_dc(blk_pos, absLevel[a1], coeff_sign, coeff_ctx[i],
+                             dc_sign_ctx, txb_costs, tmp_sign, plane);
+      rate[2 * i] += mid_cost0;
+      rate[2 * i + 1] += mid_cost1;
+    }
+  } else if (idx > 4) {
+    assert(plane == 0);
+    for (int i = 0; i < TOTALSTATES; i++) {
+      int a0 = i & 2 ? 1 : 0;
+      int a1 = a0 + 2;
+      int mid_cost0 =
+          get_mid_cost_lf(absLevel[a0], coeff_ctx[i], txb_costs, plane);
+      int mid_cost1 =
+          get_mid_cost_lf(absLevel[a1], coeff_ctx[i], txb_costs, plane);
+      rate[2 * i] += mid_cost0;
+      rate[2 * i + 1] += mid_cost1;
+    }
   }
 }
 
@@ -404,3 +587,83 @@
     tcq_ctx[i].lev[scan_idx] = AOMMIN(absLevel, INT8_MAX);
   }
 }
+
+void av1_get_rate_dist_def_chroma_avx2(
+    const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
+    const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int plane,
+    int t_sign, int sign, int32_t rate_zero[TOTALSTATES],
+    int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]) {
+  const int32_t(*cost_low)[4][SIG_COEF_CONTEXTS] = txb_costs->base_cost_uv_low;
+  const uint16_t(*cost_low_tbl)[SIG_COEF_CONTEXTS][DQ_CTXS][2] =
+      txb_costs->base_cost_uv_low_tbl;
+  const tran_low_t *absLevel = pq->absLevel;
+  const int64_t *deltaDist = pq->deltaDist;
+  // Per state s this fills dist[2s..2s+1], rate_zero[s] and rate[2s..2s+1].
+  // Copy distortion stats, using the same (a, a + 2) pairing as rate[] below.
+  __m256i delta_dist = _mm256_lddqu_si256((__m256i *)deltaDist);  // d0..d3
+  __m256i dist02 = _mm256_permute4x64_epi64(delta_dist, 0x88);  // d0 d2 d0 d2
+  __m256i dist13 = _mm256_permute4x64_epi64(delta_dist, 0xDD);  // d1 d3 d1 d3
+  _mm256_storeu_si256((__m256i *)&dist[0], dist02);  // states 0, 1
+  _mm256_storeu_si256((__m256i *)&dist[4], dist13);  // states 2, 3
+#if MORESTATES
+  _mm256_storeu_si256((__m256i *)&dist[8], dist02);  // states 4, 5
+  _mm256_storeu_si256((__m256i *)&dist[12], dist13);  // states 6, 7
+#endif
+  // Calc zero coeff costs: rate_zero[s] = cost_low[dq][0][diag_ctx + c], where
+  // dq = (s >> 1) & 1 and c = coeff_ctx[s] & 7 (the permute reads 3 index bits).
+  __m256i zero = _mm256_setzero_si256();
+  __m256i cost_zero_dq0 =
+      _mm256_lddqu_si256((__m256i *)&cost_low[0][0][diag_ctx]);
+  __m256i cost_zero_dq1 =
+      _mm256_lddqu_si256((__m256i *)&cost_low[1][0][diag_ctx]);
+  __m256i ctx = _mm256_castsi128_si256(_mm_loadu_si64(coeff_ctx));  // 8 bytes
+  ctx = _mm256_unpacklo_epi8(ctx, zero);  // widen each context to 16 bits
+  ctx = _mm256_shuffle_epi32(ctx, 0xD8);  // order: s0 s1 s4 s5 s2 s3 s6 s7
+  __m256i ctx_dq0 = _mm256_unpacklo_epi16(ctx, zero);  // states 0, 1, 4, 5
+  __m256i ctx_dq1 = _mm256_unpackhi_epi16(ctx, zero);  // states 2, 3, 6, 7
+  __m256i ratez_dq0 = _mm256_permutevar8x32_epi32(cost_zero_dq0, ctx_dq0);
+  __m256i ratez_dq1 = _mm256_permutevar8x32_epi32(cost_zero_dq1, ctx_dq1);
+  __m256i ratez_0123 = _mm256_unpacklo_epi64(ratez_dq0, ratez_dq1);
+  _mm_storeu_si128((__m128i *)&rate_zero[0],
+                   _mm256_castsi256_si128(ratez_0123));
+#if MORESTATES
+  __m256i ratez_4567 = _mm256_unpackhi_epi64(ratez_dq0, ratez_dq1);
+  _mm_storeu_si128((__m128i *)&rate_zero[4],
+                   _mm256_castsi256_si128(ratez_4567));
+#endif
+  // Calc coeff_base rate. Only the low 32 bits of each 64-bit load (two uint16
+  // costs) survive the unpacks and are zero-extended into rate[2s..2s+1].
+  int idx = AOMMIN(pq->qIdx - 1, 4);  // table row, capped at 4
+  for (int i = 0; i < TOTALSTATES / 4; i++) {
+    int j = 4 * i;  // first state of this group of four
+    int ctx0 = diag_ctx + (coeff_ctx[j + 0] & 15);
+    int ctx1 = diag_ctx + (coeff_ctx[j + 1] & 15);
+    int ctx2 = diag_ctx + (coeff_ctx[j + 2] & 15);
+    int ctx3 = diag_ctx + (coeff_ctx[j + 3] & 15);
+    __m128i rate_01 = _mm_loadu_si64(&cost_low_tbl[idx][ctx0][0]);  // state j
+    __m128i rate_23 = _mm_loadu_si64(&cost_low_tbl[idx][ctx1][0]);  // j + 1
+    __m128i rate_45 = _mm_loadu_si64(&cost_low_tbl[idx][ctx2][1]);  // j + 2
+    __m128i rate_67 = _mm_loadu_si64(&cost_low_tbl[idx][ctx3][1]);  // j + 3
+    __m128i rate_0123 = _mm_unpacklo_epi32(rate_01, rate_23);  // j, j + 1
+    __m128i rate_4567 = _mm_unpacklo_epi32(rate_45, rate_67);  // j + 2, j + 3
+    __m128i c_zero = _mm_setzero_si128();
+    rate_0123 = _mm_unpacklo_epi16(rate_0123, c_zero);  // u16 -> 32 bits
+    rate_4567 = _mm_unpacklo_epi16(rate_4567, c_zero);  // u16 -> 32 bits
+    _mm_storeu_si128((__m128i *)&rate[8 * i], rate_0123);
+    _mm_storeu_si128((__m128i *)&rate[8 * i + 4], rate_4567);
+  }
+  // Calc coeff mid and high range cost: add get_mid_cost_def() for both levels
+  // of every state. Skipped only when idx <= 0 and plane == 0.
+  if (idx > 0 || plane) {
+    for (int i = 0; i < TOTALSTATES; i++) {
+      int a0 = i & 2 ? 1 : 0;  // state bit 1 selects levels {0,2} or {1,3}
+      int a1 = a0 + 2;
+      int mid_cost0 = get_mid_cost_def(absLevel[a0], coeff_ctx[i], txb_costs,
+                                       plane, t_sign, sign);
+      int mid_cost1 = get_mid_cost_def(absLevel[a1], coeff_ctx[i], txb_costs,
+                                       plane, t_sign, sign);
+      rate[2 * i] += mid_cost0;
+      rate[2 * i + 1] += mid_cost1;
+    }
+  }
+}