[tcq] Enable 8-state command line options
[tcq] Organize tcq params
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 9f290ab..7fe6cc2 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -322,31 +322,41 @@
 
   # trellis quant
   if (aom_config("CONFIG_DQ") eq "yes") {
-    add_proto qw/void av1_decide_states/, "const struct tcq_node_t *prev, const struct tcq_rate_t *rd, const struct prequant_t *pq, int limits, int tru_eob, int64_t rdmult, struct tcq_node_t *decision";
+    add_proto qw/void av1_decide_states/, "const struct tcq_node_t *prev, const struct tcq_rate_t *rd, const struct prequant_t *pq, int n_states, int limits, int tru_eob, int64_t rdmult, struct tcq_node_t *decision";
     specialize qw/av1_decide_states avx2/;
+    add_proto qw/void av1_decide_states_st4/, "const struct tcq_node_t *prev, const struct tcq_rate_t *rd, const struct prequant_t *pq, int n_states, int limits, int tru_eob, int64_t rdmult, struct tcq_node_t *decision";
+    specialize qw/av1_decide_states_st4 avx2/;
     add_proto qw/void av1_pre_quant/, "tran_low_t tqc, struct prequant_t* pqData, const int32_t* quant_ptr, int dqv, int log_scale, int scan_pos";
     specialize qw/av1_pre_quant avx2/;
     add_proto qw/void av1_calc_diag_ctx/, "int scan_hi, int scan_lo, int bwl, const uint8_t *prev_levels, const int16_t* scan, uint8_t *ctx";
     specialize qw/av1_calc_diag_ctx avx2/;
-    add_proto qw/void av1_get_rate_dist_def_luma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class, int diag_ctx, int eob_rate, struct tcq_rate_t *rd";
+
+    add_proto qw/void av1_get_rate_dist_def_luma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class, int diag_ctx, int eob_rate, int n_states, struct tcq_rate_t *rd";
     specialize qw/av1_get_rate_dist_def_luma avx2/;
-    add_proto qw/void av1_get_rate_dist_def_chroma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class, int diag_ctx, int eob_rate, int plane, int t_sign, int sign, struct tcq_rate_t *rd";
+    add_proto qw/void av1_get_rate_dist_def_luma_st4/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class, int diag_ctx, int eob_rate, int n_states, struct tcq_rate_t *rd";
+    specialize qw/av1_get_rate_dist_def_luma_st4 avx2/;
+    add_proto qw/void av1_get_rate_dist_def_chroma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class, int diag_ctx, int eob_rate, int plane, int t_sign, int sign, int n_states, struct tcq_rate_t *rd";
     specialize qw/av1_get_rate_dist_def_chroma avx2/;
-    add_proto qw/void av1_get_rate_dist_lf_luma/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx, int eob_rate, int dc_sign_ctx, int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int coeff_sign, struct tcq_rate_t *rd";
+    add_proto qw/void av1_get_rate_dist_lf_luma/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx, int eob_rate, int dc_sign_ctx, const int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int coeff_sign, int n_states, struct tcq_rate_t *rd";
     specialize qw/av1_get_rate_dist_lf_luma avx2/;
-    add_proto qw/void av1_get_rate_dist_lf_chroma/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx, int eob_rate, int dc_sign_ctx, int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int plane, int coeff_sign, struct tcq_rate_t *rd";
+    add_proto qw/void av1_get_rate_dist_lf_luma_st4/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx, int eob_rate, int dc_sign_ctx, const int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int coeff_sign, int n_states, struct tcq_rate_t *rd";
+    specialize qw/av1_get_rate_dist_lf_luma_st4 avx2/;
+    add_proto qw/void av1_get_rate_dist_lf_chroma/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx, int eob_rate, int dc_sign_ctx, const int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int plane, int coeff_sign, int n_states, struct tcq_rate_t *rd";
     specialize qw/av1_get_rate_dist_lf_chroma avx2/;
-    add_proto qw/void av1_update_states/, "struct tcq_node_t *decision, int scan_idx, const struct tcq_ctx_t *cur_ctx, struct tcq_ctx_t *nxt_ctx";
-    #specialize qw/av1_update_states avx2/;
+
+    add_proto qw/void av1_update_states/, "struct tcq_node_t *decision, int scan_idx, int n_states, const struct tcq_ctx_t *cur_ctx, struct tcq_ctx_t *nxt_ctx";
+    specialize qw/av1_update_states avx2/;
     add_proto qw/void av1_init_lf_ctx/, "const uint8_t *lev, int scan_hi, int bwl, struct tcq_lf_ctx_t *lf_ctx";
     specialize qw/av1_init_lf_ctx avx2/;
-    add_proto qw/void av1_calc_lf_ctx/, "const struct tcq_lf_ctx_t *lf_ctx, int scan_pos, struct tcq_coeff_ctx_t *coeff_ctx";
-    specialize qw/av1_calc_lf_ctx avx2/;
-    add_proto qw/void av1_update_lf_ctx/, "const struct tcq_node_t *decision, struct tcq_lf_ctx_t *lf_ctx";
+    add_proto qw/void av1_calc_lf_ctx_st4/, "const struct tcq_lf_ctx_t *lf_ctx, int scan_pos, struct tcq_coeff_ctx_t *coeff_ctx";
+    specialize qw/av1_calc_lf_ctx_st4 avx2/;
+    add_proto qw/void av1_calc_lf_ctx_st8/, "const struct tcq_lf_ctx_t *lf_ctx, int scan_pos, struct tcq_coeff_ctx_t *coeff_ctx";
+    specialize qw/av1_calc_lf_ctx_st8 avx2/;
+    add_proto qw/void av1_update_lf_ctx/, "const struct tcq_node_t *decision, int n_states, struct tcq_lf_ctx_t *lf_ctx";
     specialize qw/av1_update_lf_ctx avx2/;
     add_proto qw/void av1_calc_block_eob_rate/, "struct macroblock *x, int plane, TX_SIZE tx_size, int eob, uint16_t *block_eob_rate";
     specialize qw/av1_calc_block_eob_rate avx2/;
-    add_proto qw/int av1_find_best_path/, "const struct tcq_node_t *trellis, const int16_t *scan, const int32_t *dequant, const qm_val_t *iqmatrix, const tran_low_t *tcoeff, int first_scan_pos, int log_scale, tran_low_t *qcoeff, tran_low_t *dqcoeff, int *min_rate, int64_t *min_cost";
+    add_proto qw/int av1_find_best_path/, "const struct tcq_node_t *trellis, int n_states_log2, const int16_t *scan, const int32_t *dequant, const qm_val_t *iqmatrix, const tran_low_t *tcoeff, int first_scan_pos, int log_scale, tran_low_t *qcoeff, tran_low_t *dqcoeff, int *min_rate, int64_t *min_cost";
     specialize qw/av1_find_best_path avx2/;
   }
 
diff --git a/av1/common/quant_common.h b/av1/common/quant_common.h
index f7613d9..d10ffe6 100644
--- a/av1/common/quant_common.h
+++ b/av1/common/quant_common.h
@@ -27,20 +27,14 @@
 #define TCQ_HDR_FLAG 1  // Enable through header flag(s)
 #define DQENABLE 0      // Determine whether to use DQ by dq_enable()
 #define NEWQINDEX 1     // QP shift
-#define MORESTATES 0    // 1: 8-state; 0: 4-state
 #define NEWHR 1         // 1:parity is determined by (base + LR)
 #else
 #define TCQ_HDR_FLAG 0
-#define DQENABLE 0    // Determine whether to use DQ by dq_enable()
-#define NEWQINDEX 0   // QP shift
-#define MORESTATES 0  // 1: 8-state; 0: 4-state
+#define DQENABLE 0   // Determine whether to use DQ by dq_enable()
+#define NEWQINDEX 0  // QP shift
 #define NEWHR 0
 #endif
-#if MORESTATES
-#define TOTALSTATES 8
-#else
-#define TOTALSTATES 4
-#endif
+#define TCQ_MAX_STATES 8
 
 #define PHTHRESH 4
 #define MINQ 0
diff --git a/av1/encoder/trellis_quant.c b/av1/encoder/trellis_quant.c
index 72cdcd9..0249ea7 100644
--- a/av1/encoder/trellis_quant.c
+++ b/av1/encoder/trellis_quant.c
@@ -24,6 +24,25 @@
 #include "av1/encoder/rdopt.h"
 #include "av1/encoder/tokenize.h"
 
+typedef void (*DecideStateFnc)(const struct tcq_node_t *prev,
+                               const struct tcq_rate_t *rd,
+                               const struct prequant_t *pq, int n_states,
+                               int limits, int try_eob, int64_t rdmult,
+                               struct tcq_node_t *decision);
+typedef void (*GetLfLumaRateDistFnc)(const struct LV_MAP_COEFF_COST *txb_costs,
+                                     const struct prequant_t *pq,
+                                     const struct tcq_coeff_ctx_t *coeff_ctx,
+                                     int blk_pos, int diag_ctx, int eob_rate,
+                                     int dc_sign_ctx, const int32_t *tmp_sign,
+                                     int bwl, TX_CLASS tx_class, int coeff_sign,
+                                     int n_states, struct tcq_rate_t *rd);
+typedef void (*GetDefLumaRateDistFnc)(const struct LV_MAP_COEFF_COST *txb_costs,
+                                      const struct prequant_t *pq,
+                                      const struct tcq_coeff_ctx_t *coeff_ctx,
+                                      int blk_pos, int bwl, TX_CLASS tx_class,
+                                      int diag_ctx, int eob_rate, int n_states,
+                                      struct tcq_rate_t *rd);
+
 typedef struct {
   uint8_t *base;
   int bufsize;
@@ -46,19 +65,9 @@
   return &lev->base[(2 * st + !lev->idx) * lev->bufsize];
 }
 
-#if MORESTATES
-static const uint8_t next_st[TOTALSTATES][2] = { { 0, 4 }, { 4, 0 }, { 1, 5 },
-                                                 { 5, 1 }, { 6, 2 }, { 2, 6 },
-                                                 { 7, 3 }, { 3, 7 } };
-#else
-static const uint8_t next_st[TOTALSTATES][2] = {
-  { 0, 2 }, { 2, 0 }, { 1, 3 }, { 3, 1 }
-};
-#endif
-
-static AOM_INLINE void init_tcq_decision(tcq_node_t *decision) {
+static AOM_INLINE void init_tcq_decision(tcq_node_t *decision, int n_states) {
   static const tcq_node_t def = { INT64_MAX >> 10, INT32_MAX >> 1, -1, -2 };
-  for (int state = 0; state < TOTALSTATES; state++) {
+  for (int state = 0; state < n_states; state++) {
     memcpy(&decision[state], &def, sizeof(def));
   }
 }
@@ -376,7 +385,7 @@
 static INLINE int get_coeff_cost_general(
     int ci, tran_low_t abs_qc, int sign, int coeff_ctx, int mid_ctx,
     int dc_sign_ctx, const LV_MAP_COEFF_COST *txb_costs, int bwl,
-    TX_CLASS tx_class, int32_t *tmp_sign, int plane, int limits, int dq) {
+    TX_CLASS tx_class, const int32_t *tmp_sign, int plane, int limits, int dq) {
   int cost = 0;
   const int(*base_lf_cost_ptr)[DQ_CTXS][LF_BASE_SYMBOLS * 2] =
       plane > 0 ? txb_costs->base_lf_cost_uv : txb_costs->base_lf_cost;
@@ -436,7 +445,7 @@
                                          const LV_MAP_COEFF_COST *txb_costs,
                                          int bwl, TX_CLASS tx_class,
 #if CONFIG_CONTEXT_DERIVATION
-                                         int32_t *tmp_sign,
+                                         const int32_t *tmp_sign,
 #endif  // CONFIG_CONTEXT_DERIVATION
                                          int plane, int limits, int dq) {
   int cost = 0;
@@ -762,16 +771,21 @@
 
 void av1_decide_states_c(const struct tcq_node_t *prev,
                          const struct tcq_rate_t *rd,
-                         const struct prequant_t *pq, int limits, int try_eob,
-                         int64_t rdmult, struct tcq_node_t *decision) {
+                         const struct prequant_t *pq, int n_states, int limits,
+                         int try_eob, int64_t rdmult,
+                         struct tcq_node_t *decision) {
   const int32_t *rate = rd->rate;
   const int32_t *rate_zero = rd->rate_zero;
   const int32_t *rate_eob = rd->rate_eob;
-  int64_t rdCost[2 * TOTALSTATES];
-  int64_t rdCost_zero[TOTALSTATES];
+  int64_t rdCost[2 * TCQ_MAX_STATES];
+  int64_t rdCost_zero[TCQ_MAX_STATES];
   int64_t rdCost_eob[2];
 
-  for (int i = 0; i < TOTALSTATES; i++) {
+  // Init for ASAN
+  memset(rdCost, 0, sizeof(rdCost));
+  memset(rdCost_zero, 0, sizeof(rdCost_zero));
+
+  for (int i = 0; i < n_states; i++) {
     int a0 = tcq_quant(i);
     int a1 = a0 + 2;
     int64_t dist0 = pq->deltaDist[a0];
@@ -782,54 +796,63 @@
   }
   rdCost_eob[0] = RDCOST(rdmult, rate_eob[0], pq->deltaDist[0]);
   rdCost_eob[1] = RDCOST(rdmult, rate_eob[1], pq->deltaDist[2]);
-#if MORESTATES
-  decide_new(rdCost[0], rdCost[1], rdCost_zero[0], rate[0], rate[1],
-             rate_zero[0], pq->absLevel[0], pq->absLevel[2], limits,
-             prev[0].rate, 0, &decision[0], &decision[4]);
-  decide_new(rdCost[2], rdCost[3], rdCost_zero[1], rate[2], rate[3],
-             rate_zero[1], pq->absLevel[0], pq->absLevel[2], limits,
-             prev[1].rate, 1, &decision[4], &decision[0]);
-  decide_new(rdCost[5], rdCost[4], rdCost_zero[2], rate[5], rate[4],
-             rate_zero[2], pq->absLevel[3], pq->absLevel[1], limits,
-             prev[2].rate, 2, &decision[1], &decision[5]);
-  decide_new(rdCost[7], rdCost[6], rdCost_zero[3], rate[7], rate[6],
-             rate_zero[3], pq->absLevel[3], pq->absLevel[1], limits,
-             prev[3].rate, 3, &decision[5], &decision[1]);
-  decide_new(rdCost[8], rdCost[9], rdCost_zero[4], rate[8], rate[9],
-             rate_zero[4], pq->absLevel[0], pq->absLevel[2], limits,
-             prev[4].rate, 4, &decision[6], &decision[2]);
-  decide_new(rdCost[10], rdCost[11], rdCost_zero[5], rate[10], rate[11],
-             rate_zero[5], pq->absLevel[0], pq->absLevel[2], limits,
-             prev[5].rate, 5, &decision[2], &decision[6]);
-  decide_new(rdCost[13], rdCost[12], rdCost_zero[6], rate[13], rate[12],
-             rate_zero[6], pq->absLevel[3], pq->absLevel[1], limits,
-             prev[6].rate, 6, &decision[7], &decision[3]);
-  decide_new(rdCost[15], rdCost[14], rdCost_zero[7], rate[15], rate[14],
-             rate_zero[7], pq->absLevel[3], pq->absLevel[1], limits,
-             prev[7].rate, 7, &decision[3], &decision[7]);
-#else
-  decide_new(rdCost[0], rdCost[1], rdCost_zero[0], rate[0], rate[1],
-             rate_zero[0], pq->absLevel[0], pq->absLevel[2], limits,
-             prev[0].rate, 0, &decision[0], &decision[2]);
-  decide_new(rdCost[2], rdCost[3], rdCost_zero[1], rate[2], rate[3],
-             rate_zero[1], pq->absLevel[0], pq->absLevel[2], limits,
-             prev[1].rate, 1, &decision[2], &decision[0]);
-  decide_new(rdCost[5], rdCost[4], rdCost_zero[2], rate[5], rate[4],
-             rate_zero[2], pq->absLevel[3], pq->absLevel[1], limits,
-             prev[2].rate, 2, &decision[1], &decision[3]);
-  decide_new(rdCost[7], rdCost[6], rdCost_zero[3], rate[7], rate[6],
-             rate_zero[3], pq->absLevel[3], pq->absLevel[1], limits,
-             prev[3].rate, 3, &decision[3], &decision[1]);
-#endif
+  if (n_states == 4) {
+    decide_new(rdCost[0], rdCost[1], rdCost_zero[0], rate[0], rate[1],
+               rate_zero[0], pq->absLevel[0], pq->absLevel[2], limits,
+               prev[0].rate, 0, &decision[0], &decision[2]);
+    decide_new(rdCost[2], rdCost[3], rdCost_zero[1], rate[2], rate[3],
+               rate_zero[1], pq->absLevel[0], pq->absLevel[2], limits,
+               prev[1].rate, 1, &decision[2], &decision[0]);
+    decide_new(rdCost[5], rdCost[4], rdCost_zero[2], rate[5], rate[4],
+               rate_zero[2], pq->absLevel[3], pq->absLevel[1], limits,
+               prev[2].rate, 2, &decision[1], &decision[3]);
+    decide_new(rdCost[7], rdCost[6], rdCost_zero[3], rate[7], rate[6],
+               rate_zero[3], pq->absLevel[3], pq->absLevel[1], limits,
+               prev[3].rate, 3, &decision[3], &decision[1]);
+  } else {  // n_states == 8
+    decide_new(rdCost[0], rdCost[1], rdCost_zero[0], rate[0], rate[1],
+               rate_zero[0], pq->absLevel[0], pq->absLevel[2], limits,
+               prev[0].rate, 0, &decision[0], &decision[4]);
+    decide_new(rdCost[2], rdCost[3], rdCost_zero[1], rate[2], rate[3],
+               rate_zero[1], pq->absLevel[0], pq->absLevel[2], limits,
+               prev[1].rate, 1, &decision[4], &decision[0]);
+    decide_new(rdCost[5], rdCost[4], rdCost_zero[2], rate[5], rate[4],
+               rate_zero[2], pq->absLevel[3], pq->absLevel[1], limits,
+               prev[2].rate, 2, &decision[1], &decision[5]);
+    decide_new(rdCost[7], rdCost[6], rdCost_zero[3], rate[7], rate[6],
+               rate_zero[3], pq->absLevel[3], pq->absLevel[1], limits,
+               prev[3].rate, 3, &decision[5], &decision[1]);
+    decide_new(rdCost[8], rdCost[9], rdCost_zero[4], rate[8], rate[9],
+               rate_zero[4], pq->absLevel[0], pq->absLevel[2], limits,
+               prev[4].rate, 4, &decision[6], &decision[2]);
+    decide_new(rdCost[10], rdCost[11], rdCost_zero[5], rate[10], rate[11],
+               rate_zero[5], pq->absLevel[0], pq->absLevel[2], limits,
+               prev[5].rate, 5, &decision[2], &decision[6]);
+    decide_new(rdCost[13], rdCost[12], rdCost_zero[6], rate[13], rate[12],
+               rate_zero[6], pq->absLevel[3], pq->absLevel[1], limits,
+               prev[6].rate, 6, &decision[7], &decision[3]);
+    decide_new(rdCost[15], rdCost[14], rdCost_zero[7], rate[15], rate[14],
+               rate_zero[7], pq->absLevel[3], pq->absLevel[1], limits,
+               prev[7].rate, 7, &decision[3], &decision[7]);
+  }
   if (try_eob) {
-    const int state0 = next_st[0][0];
-    const int state1 = next_st[0][1];
+    const int state0 = 0;
+    const int state1 = n_states == 8 ? 4 : 2;
     decide_eob(rdCost_eob[0], rdCost_eob[1], rate_eob[0], rate_eob[1],
                pq->absLevel[0], pq->absLevel[2], &decision[state0],
                &decision[state1]);
   }
 }
 
+void av1_decide_states_st4_c(const struct tcq_node_t *prev,
+                             const struct tcq_rate_t *rd,
+                             const struct prequant_t *pq, int n_states,
+                             int limits, int try_eob, int64_t rdmult,
+                             struct tcq_node_t *decision) {
+  av1_decide_states_c(prev, rd, pq, n_states, limits, try_eob, rdmult,
+                      decision);
+}
+
 void av1_pre_quant_c(tran_low_t tqc, struct prequant_t *pqData,
                      const int32_t *quant_ptr, int dqv, int log_scale,
                      int scan_pos) {
@@ -889,7 +912,7 @@
 static int get_coeff_cost(int ci, tran_low_t abs_qc, int sign, int coeff_ctx,
                           int mid_ctx, int dc_sign_ctx,
                           const LV_MAP_COEFF_COST *txb_costs, int bwl,
-                          TX_CLASS tx_class, int32_t *tmp_sign, int plane,
+                          TX_CLASS tx_class, const int32_t *tmp_sign, int plane,
                           int limits, int dq) {
   return get_coeff_cost_general(ci, abs_qc, sign, coeff_ctx, mid_ctx,
                                 dc_sign_ctx, txb_costs, bwl, tx_class,
@@ -899,28 +922,38 @@
                                 plane, limits, dq);
 }
 
-void trellis_first_pos(int scan_pos, int plane, TX_SIZE tx_size,
-                       TX_CLASS tx_class, int32_t *tmp_sign, int sharpness,
-                       tcq_levels_t *tcq_lev,
-                       tcq_node_t trellis[MAX_TRELLIS][TOTALSTATES],
-                       tran_low_t *qcoeff, const int64_t rdmult, int log_scale,
-                       const int16_t *scan, const tran_low_t *tcoeff,
-                       const int32_t *dequant, const int32_t *quant,
-                       const qm_val_t *iqmatrix, const uint16_t *block_eob_rate,
-                       const TXB_CTX *const txb_ctx,
-                       const LV_MAP_COEFF_COST *txb_costs) {
+void trellis_first_pos(const tcq_param_t *p, int scan_pos,
+                       tcq_levels_t *tcq_lev, tcq_node_t *trellis) {
+  int n_states = p->n_states;
+  int n_states_log2 = p->n_states_log2;
+  int plane = p->plane;
+  TX_SIZE tx_size = p->tx_size;
+  TX_CLASS tx_class = p->tx_class;
+  int log_scale = p->log_scale;
+  int sharpness = p->sharpness;
+  int64_t rdmult = p->rdmult;
+  const int16_t *scan = p->scan;
+  const int32_t *tmp_sign = p->tmp_sign;
+  const tran_low_t *qcoeff = p->qcoeff;
+  const tran_low_t *tcoeff = p->tcoeff;
+  const int32_t *quant = p->quant;
+  const int32_t *dequant = p->dequant;
+  const qm_val_t *iqmatrix = p->iqmatrix;
+  const uint16_t *block_eob_rate = p->block_eob_rate;
+  const TXB_CTX *txb_ctx = p->txb_ctx;
+  const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
   const int bwl = get_txb_bwl(tx_size);
   const int height = get_txb_high(tx_size);
 
   int blk_pos = scan[scan_pos];
-  tcq_node_t *decision = trellis[scan_pos];
+  tcq_node_t *decision = &trellis[scan_pos << n_states_log2];
 
   prequant_t pqData;
   int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
   av1_pre_quant(tcoeff[blk_pos], &pqData, quant, tempdqv, log_scale, scan_pos);
 
   // init state
-  init_tcq_decision(decision);
+  init_tcq_decision(decision, n_states);
 
   const int row = blk_pos >> bwl;
   const int col = blk_pos - (row << bwl);
@@ -952,8 +985,8 @@
                          ,
                          plane) +
       eob_rate;
-  const int state0 = next_st[0][0];
-  const int state1 = next_st[0][1];
+  const int state0 = 0;
+  const int state1 = (n_states == 4) ? 2 : 4;
   decide(0, pqData.deltaDist[0], pqData.deltaDist[2], rdmult, rate_Q0_a,
          rate_Q0_b, INT32_MAX >> 1, pqData.absLevel[0], pqData.absLevel[2],
          limits, 0, -1, &decision[state0], &decision[state1]);
@@ -970,7 +1003,7 @@
                                   const struct prequant_t *pq,
                                   const struct tcq_coeff_ctx_t *coeff_ctx,
                                   int blk_pos, int bwl, TX_CLASS tx_class,
-                                  int diag_ctx, int eob_rate,
+                                  int diag_ctx, int eob_rate, int n_states,
                                   struct tcq_rate_t *rd) {
   const int plane = 0;
   const int t_sign = 0;
@@ -978,7 +1011,7 @@
   const int dc_sign_ctx = 0;
   const tran_low_t *absLevel = pq->absLevel;
 
-  for (int i = 0; i < TOTALSTATES; i++) {
+  for (int i = 0; i < n_states; i++) {
     int dq = tcq_quant(i);
     int a0 = dq;
     int a1 = a0 + 2;
@@ -1001,17 +1034,28 @@
                                     bwl, tx_class, t_sign, plane);
 }
 
+// Same as above function, but specialized to 4 states in the SIMD version.
+void av1_get_rate_dist_def_luma_st4_c(const struct LV_MAP_COEFF_COST *txb_costs,
+                                      const struct prequant_t *pq,
+                                      const struct tcq_coeff_ctx_t *coeff_ctx,
+                                      int blk_pos, int bwl, TX_CLASS tx_class,
+                                      int diag_ctx, int eob_rate, int n_states,
+                                      struct tcq_rate_t *rd) {
+  av1_get_rate_dist_def_luma_c(txb_costs, pq, coeff_ctx, blk_pos, bwl, tx_class,
+                               diag_ctx, eob_rate, n_states, rd);
+}
+
 void av1_get_rate_dist_def_chroma_c(const struct LV_MAP_COEFF_COST *txb_costs,
                                     const struct prequant_t *pq,
                                     const struct tcq_coeff_ctx_t *coeff_ctx,
                                     int blk_pos, int bwl, TX_CLASS tx_class,
                                     int diag_ctx, int eob_rate, int plane,
-                                    int t_sign, int sign,
+                                    int t_sign, int sign, int n_states,
                                     struct tcq_rate_t *rd) {
   const tran_low_t *absLevel = pq->absLevel;
   const int dc_sign_ctx = 0;
 
-  for (int i = 0; i < TOTALSTATES; i++) {
+  for (int i = 0; i < n_states; i++) {
     int dq = tcq_quant(i);
     int a0 = dq;
     int a1 = a0 + 2;
@@ -1038,28 +1082,28 @@
                                  const struct prequant_t *pq,
                                  const struct tcq_coeff_ctx_t *coeff_ctx,
                                  int blk_pos, int diag_ctx, int eob_rate,
-                                 int dc_sign_ctx, int32_t *tmp_sign, int bwl,
-                                 TX_CLASS tx_class, int coeff_sign,
-                                 struct tcq_rate_t *rd) {
+                                 int dc_sign_ctx, const int32_t *tmp_sign,
+                                 int bwl, TX_CLASS tx_class, int coeff_sign,
+                                 int n_states, struct tcq_rate_t *rd) {
   const tran_low_t *absLevel = pq->absLevel;
-  uint8_t base_ctx[TOTALSTATES];
-  uint8_t mid_ctx[TOTALSTATES];
+  uint8_t base_ctx;
+  uint8_t mid_ctx;
   int t_sign = tmp_sign[blk_pos];
   int plane = 0;
 
-  for (int i = 0; i < TOTALSTATES; i++) {
+  for (int i = 0; i < n_states; i++) {
     int dq = tcq_quant(i);
     int a0 = dq;
     int a1 = a0 + 2;
-    base_ctx[i] = (coeff_ctx->coef[i] & 15) + diag_ctx;
-    mid_ctx[i] = coeff_ctx->coef[i] >> 4;
-    int cost0 = get_coeff_cost(blk_pos, absLevel[a0], coeff_sign, base_ctx[i],
-                               mid_ctx[i], dc_sign_ctx, txb_costs, bwl,
-                               tx_class, tmp_sign, plane, 1, dq);
-    int cost1 = get_coeff_cost(blk_pos, absLevel[a1], coeff_sign, base_ctx[i],
-                               mid_ctx[i], dc_sign_ctx, txb_costs, bwl,
-                               tx_class, tmp_sign, plane, 1, dq);
-    rd->rate_zero[i] = txb_costs->base_lf_cost[base_ctx[i]][dq][0];
+    base_ctx = (coeff_ctx->coef[i] & 15) + diag_ctx;
+    mid_ctx = coeff_ctx->coef[i] >> 4;
+    int cost0 = get_coeff_cost(blk_pos, absLevel[a0], coeff_sign, base_ctx,
+                               mid_ctx, dc_sign_ctx, txb_costs, bwl, tx_class,
+                               tmp_sign, plane, 1, dq);
+    int cost1 = get_coeff_cost(blk_pos, absLevel[a1], coeff_sign, base_ctx,
+                               mid_ctx, dc_sign_ctx, txb_costs, bwl, tx_class,
+                               tmp_sign, plane, 1, dq);
+    rd->rate_zero[i] = txb_costs->base_lf_cost[base_ctx][dq][0];
     rd->rate[2 * i] = cost0;
     rd->rate[2 * i + 1] = cost1;
   }
@@ -1073,31 +1117,45 @@
                                     bwl, tx_class, t_sign, plane);
 }
 
+// Same as above function, but specialized to 4 states in the SIMD version.
+void av1_get_rate_dist_lf_luma_st4_c(const struct LV_MAP_COEFF_COST *txb_costs,
+                                     const struct prequant_t *pq,
+                                     const struct tcq_coeff_ctx_t *coeff_ctx,
+                                     int blk_pos, int diag_ctx, int eob_rate,
+                                     int dc_sign_ctx, const int32_t *tmp_sign,
+                                     int bwl, TX_CLASS tx_class, int coeff_sign,
+                                     int n_states, struct tcq_rate_t *rd) {
+  av1_get_rate_dist_lf_luma_c(txb_costs, pq, coeff_ctx, blk_pos, diag_ctx,
+                              eob_rate, dc_sign_ctx, tmp_sign, bwl, tx_class,
+                              coeff_sign, n_states, rd);
+}
+
 void av1_get_rate_dist_lf_chroma_c(const struct LV_MAP_COEFF_COST *txb_costs,
                                    const struct prequant_t *pq,
                                    const struct tcq_coeff_ctx_t *coeff_ctx,
                                    int blk_pos, int diag_ctx, int eob_rate,
-                                   int dc_sign_ctx, int32_t *tmp_sign, int bwl,
-                                   TX_CLASS tx_class, int plane, int coeff_sign,
+                                   int dc_sign_ctx, const int32_t *tmp_sign,
+                                   int bwl, TX_CLASS tx_class, int plane,
+                                   int coeff_sign, int n_states,
                                    struct tcq_rate_t *rd) {
   const tran_low_t *absLevel = pq->absLevel;
-  uint8_t base_ctx[TOTALSTATES];
-  uint8_t mid_ctx[TOTALSTATES];
+  uint8_t base_ctx;
+  uint8_t mid_ctx;
   int t_sign = tmp_sign[blk_pos];
 
-  for (int i = 0; i < TOTALSTATES; i++) {
+  for (int i = 0; i < n_states; i++) {
     int dq = tcq_quant(i);
     int a0 = dq;
     int a1 = a0 + 2;
-    base_ctx[i] = (coeff_ctx->coef[i] & 15) + diag_ctx;
-    mid_ctx[i] = coeff_ctx->coef[i] >> 4;
-    int cost0 = get_coeff_cost(blk_pos, absLevel[a0], coeff_sign, base_ctx[i],
-                               mid_ctx[i], dc_sign_ctx, txb_costs, bwl,
-                               tx_class, tmp_sign, plane, 1, dq);
-    int cost1 = get_coeff_cost(blk_pos, absLevel[a1], coeff_sign, base_ctx[i],
-                               mid_ctx[i], dc_sign_ctx, txb_costs, bwl,
-                               tx_class, tmp_sign, plane, 1, dq);
-    rd->rate_zero[i] = txb_costs->base_lf_cost_uv[base_ctx[i]][dq][0];
+    base_ctx = (coeff_ctx->coef[i] & 15) + diag_ctx;
+    mid_ctx = coeff_ctx->coef[i] >> 4;
+    int cost0 = get_coeff_cost(blk_pos, absLevel[a0], coeff_sign, base_ctx,
+                               mid_ctx, dc_sign_ctx, txb_costs, bwl, tx_class,
+                               tmp_sign, plane, 1, dq);
+    int cost1 = get_coeff_cost(blk_pos, absLevel[a1], coeff_sign, base_ctx,
+                               mid_ctx, dc_sign_ctx, txb_costs, bwl, tx_class,
+                               tmp_sign, plane, 1, dq);
+    rd->rate_zero[i] = txb_costs->base_lf_cost_uv[base_ctx][dq][0];
     rd->rate[2 * i] = cost0;
     rd->rate[2 * i + 1] = cost1;
   }
@@ -1124,10 +1182,10 @@
   }
 }
 
-void av1_update_states_c(tcq_node_t *decision, int scan_idx,
+void av1_update_states_c(tcq_node_t *decision, int scan_idx, int n_states,
                          const struct tcq_ctx_t *cur_ctx,
                          struct tcq_ctx_t *nxt_ctx) {
-  for (int i = 0; i < TOTALSTATES; i++) {
+  for (int i = 0; i < n_states; i++) {
     int prevId = decision[i].prevId;
     int absLevel = decision[i].absLevel;
     if (prevId >= 0) {
@@ -1143,8 +1201,9 @@
 
 static void update_levels_diagonal(tcq_levels_t *tcq_lev, const int16_t *scan,
                                    int bufsize, int bwl, int scan_hi,
-                                   int scan_lo, const tcq_ctx_t *tcq_ctx) {
-  for (int i = 0; i < TOTALSTATES; i++) {
+                                   int scan_lo, int n_states,
+                                   const tcq_ctx_t *tcq_ctx) {
+  for (int i = 0; i < n_states; i++) {
     int orig_id = tcq_ctx[i].orig_id;
     uint8_t *cur_lev = tcq_levels_cur(tcq_lev, i);
     uint8_t *prev_lev = tcq_levels_prev(tcq_lev, orig_id);
@@ -1161,35 +1220,46 @@
 }
 
 // Handle trellis default region for Luma, TX_CLASS_2D blocks.
-void trellis_loop_diagonal(
-    int scan_hi, int scan_lo, int plane, TX_SIZE tx_size, TX_CLASS tx_class,
-    int32_t *tmp_sign, int sharpness, tcq_levels_t *tcq_lev,
-    tcq_ctx_t tcq_ctx[TOTALSTATES],
-    tcq_node_t trellis[MAX_TRELLIS][TOTALSTATES], tran_low_t *qcoeff,
-    const int64_t rdmult, int log_scale, const int16_t *scan,
-    const tran_low_t *tcoeff, const int32_t *dequant, const int32_t *quant,
-    const qm_val_t *iqmatrix, const uint16_t *block_eob_rate,
-    const TXB_CTX *const txb_ctx, const LV_MAP_COEFF_COST *txb_costs) {
+// TCQ 4-state
+static void trellis_loop_diagonal_st4(const tcq_param_t *p, int scan_hi,
+                                      int scan_lo, tcq_levels_t *tcq_lev,
+                                      tcq_ctx_t tcq_ctx[2 * TCQ_MAX_STATES],
+                                      tcq_node_t *trellis) {
+  int plane = p->plane;
+  TX_SIZE tx_size = p->tx_size;
+  TX_CLASS tx_class = p->tx_class;
+  int log_scale = p->log_scale;
+  int try_eob = p->sharpness == 0;
+  int64_t rdmult = p->rdmult;
+  const int16_t *scan = p->scan;
+  const tran_low_t *tcoeff = p->tcoeff;
+  const int32_t *quant = p->quant;
+  const int32_t *dequant = p->dequant;
+  const qm_val_t *iqmatrix = p->iqmatrix;
+  const uint16_t *block_eob_rate = p->block_eob_rate;
+  const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
   const int bwl = get_txb_bwl(tx_size);
   const int height = get_txb_high(tx_size);
   const int pos0 = scan[scan_hi];
   const int diag_ctx = get_nz_map_ctx_from_stats(0, pos0, bwl, TX_CLASS_2D, 0);
+
+  assert(p->n_states == 4);
   assert(plane == 0);
   assert(tx_class == TX_CLASS_2D);
   (void)plane;
-  (void)tmp_sign;
-  (void)qcoeff;
-  (void)txb_ctx;
+
+  const int n_st = 4;
+  const int n_st_log2 = 2;
 
   // Precompute base and mid ctx values, as they are independent across
   // the diagonal pass.
   tcq_levels_swap(tcq_lev);
 
   int i_ctx = scan_hi & 1;
-  tcq_ctx_t *cur_ctx = &tcq_ctx[i_ctx ? TOTALSTATES : 0];
-  tcq_ctx_t *nxt_ctx = &tcq_ctx[i_ctx ? 0 : TOTALSTATES];
+  tcq_ctx_t *cur_ctx = &tcq_ctx[i_ctx ? TCQ_MAX_STATES : 0];
+  tcq_ctx_t *nxt_ctx = &tcq_ctx[i_ctx ? 0 : TCQ_MAX_STATES];
 
-  for (int i = 0; i < TOTALSTATES; i++) {
+  for (int i = 0; i < n_st; i++) {
     uint8_t *prev_levels = tcq_levels_prev(tcq_lev, i);
     av1_calc_diag_ctx(scan_hi, scan_lo, bwl, prev_levels, scan, cur_ctx[i].ctx);
     cur_ctx[i].orig_id = i;
@@ -1197,8 +1267,8 @@
 
   for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
     const int blk_pos = scan[scan_pos];
-    tcq_node_t *decision = trellis[scan_pos];
-    tcq_node_t *prev_decision = trellis[scan_pos + 1];
+    tcq_node_t *decision = &trellis[scan_pos << n_st_log2];
+    tcq_node_t *prev_decision = &decision[n_st];
 
     prequant_t pqData;
     int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
@@ -1206,12 +1276,12 @@
                   scan_pos);
 
     // init state
-    init_tcq_decision(decision);
+    init_tcq_decision(decision, n_st);
     const int limits = 0;
 
     // calculate rate distortion
     tcq_coeff_ctx_t coeff_ctx;
-    for (int i = 0; i < TOTALSTATES; i++) {
+    for (int i = 0; i < n_st; i++) {
       coeff_ctx.coef[i] = cur_ctx[i].ctx[scan_pos - scan_lo];
     }
     int eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
@@ -1219,14 +1289,13 @@
     int eob_rate = block_eob_rate[scan_pos];
 
     tcq_rate_t rd;
-    av1_get_rate_dist_def_luma(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
-                               tx_class, diag_ctx, eob_rate, &rd);
+    av1_get_rate_dist_def_luma_st4(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
+                                   tx_class, diag_ctx, eob_rate, n_st, &rd);
 
-    int try_eob = sharpness == 0;
-    av1_decide_states(prev_decision, &rd, &pqData, limits, try_eob, rdmult,
-                      decision);
+    av1_decide_states_st4(prev_decision, &rd, &pqData, n_st, limits, try_eob,
+                          rdmult, decision);
 
-    av1_update_states(decision, scan_pos - scan_lo, cur_ctx, nxt_ctx);
+    av1_update_states(decision, scan_pos - scan_lo, n_st, cur_ctx, nxt_ctx);
 
     // Swap cur/nxt context.
     tcq_ctx_t *tmp = cur_ctx;
@@ -1235,7 +1304,94 @@
   }
 
   update_levels_diagonal(tcq_lev, scan, tcq_lev->bufsize, bwl, scan_hi, scan_lo,
-                         cur_ctx);
+                         n_st, cur_ctx);
+}
+
+// TCQ 8-state
+static void trellis_loop_diagonal_st8(const tcq_param_t *p, int scan_hi,
+                                      int scan_lo, tcq_levels_t *tcq_lev,
+                                      tcq_ctx_t tcq_ctx[2 * TCQ_MAX_STATES],
+                                      tcq_node_t *trellis) {
+  int plane = p->plane;
+  TX_SIZE tx_size = p->tx_size;
+  TX_CLASS tx_class = p->tx_class;
+  int log_scale = p->log_scale;
+  int try_eob = p->sharpness == 0;
+  int64_t rdmult = p->rdmult;
+  const int16_t *scan = p->scan;
+  const tran_low_t *tcoeff = p->tcoeff;
+  const int32_t *quant = p->quant;
+  const int32_t *dequant = p->dequant;
+  const qm_val_t *iqmatrix = p->iqmatrix;
+  const uint16_t *block_eob_rate = p->block_eob_rate;
+  const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
+  const int bwl = get_txb_bwl(tx_size);
+  const int height = get_txb_high(tx_size);
+  const int pos0 = scan[scan_hi];
+  const int diag_ctx = get_nz_map_ctx_from_stats(0, pos0, bwl, TX_CLASS_2D, 0);
+
+  assert(p->n_states == 8);
+  assert(plane == 0);
+  assert(tx_class == TX_CLASS_2D);
+  (void)plane;
+
+  const int n_st = 8;
+  const int n_st_log2 = 3;
+
+  // Precompute base and mid ctx values, as they are independent across
+  // the diagonal pass.
+  tcq_levels_swap(tcq_lev);
+
+  int i_ctx = scan_hi & 1;
+  tcq_ctx_t *cur_ctx = &tcq_ctx[i_ctx ? TCQ_MAX_STATES : 0];
+  tcq_ctx_t *nxt_ctx = &tcq_ctx[i_ctx ? 0 : TCQ_MAX_STATES];
+
+  for (int i = 0; i < n_st; i++) {
+    uint8_t *prev_levels = tcq_levels_prev(tcq_lev, i);
+    av1_calc_diag_ctx(scan_hi, scan_lo, bwl, prev_levels, scan, cur_ctx[i].ctx);
+    cur_ctx[i].orig_id = i;
+  }
+
+  for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
+    const int blk_pos = scan[scan_pos];
+    tcq_node_t *decision = &trellis[scan_pos << n_st_log2];
+    tcq_node_t *prev_decision = &decision[n_st];
+
+    prequant_t pqData;
+    int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
+    av1_pre_quant(tcoeff[blk_pos], &pqData, quant, tempdqv, log_scale,
+                  scan_pos);
+
+    // init state
+    init_tcq_decision(decision, n_st);
+    const int limits = 0;
+
+    // calculate rate distortion
+    tcq_coeff_ctx_t coeff_ctx;
+    for (int i = 0; i < n_st; i++) {
+      coeff_ctx.coef[i] = cur_ctx[i].ctx[scan_pos - scan_lo];
+    }
+    int eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
+    coeff_ctx.coef_eob = eob_ctx;
+    int eob_rate = block_eob_rate[scan_pos];
+
+    tcq_rate_t rd;
+    av1_get_rate_dist_def_luma(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
+                               tx_class, diag_ctx, eob_rate, n_st, &rd);
+
+    av1_decide_states(prev_decision, &rd, &pqData, n_st, limits, try_eob,
+                      rdmult, decision);
+
+    av1_update_states(decision, scan_pos - scan_lo, n_st, cur_ctx, nxt_ctx);
+
+    // Swap cur/nxt context.
+    tcq_ctx_t *tmp = cur_ctx;
+    cur_ctx = nxt_ctx;
+    nxt_ctx = tmp;
+  }
+
+  update_levels_diagonal(tcq_lev, scan, tcq_lev->bufsize, bwl, scan_hi, scan_lo,
+                         n_st, cur_ctx);
 }
 
 void av1_init_lf_ctx_c(const uint8_t *lev, int scan_hi, int bwl,
@@ -1261,8 +1417,8 @@
 // Initialize LF neighbor context.
 // The lf_ctx->last[] array tracks the last N previous coeffs (LIFO),
 // and used to calculate coeff neighbor contexts.
-void av1_calc_lf_ctx_c(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
-                       struct tcq_coeff_ctx_t *coeff_ctx) {
+void av1_calc_lf_ctx_st4_c(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
+                           struct tcq_coeff_ctx_t *coeff_ctx) {
   static const int8_t kMaxCtx[16] = { 8, 6, 6, 4, 4, 4, 4, 4,
                                       4, 4, 4, 4, 4, 4, 4, 4 };
   static const int8_t kScanDiag[MAX_LF_SCAN] = { 0, 1, 1, 2, 2, 2, 3, 3, 3, 3 };
@@ -1272,8 +1428,9 @@
     { 0, 0, 3, 3, 0, 0, 1, 3, 1, 0, 0 },  // diag 2
     { 0, 0, 0, 3, 3, 0, 0, 0, 1, 3, 1 },  // diag 3
   };
+  int n_states = 4;
 
-  for (int st = 0; st < TOTALSTATES; st++) {
+  for (int st = 0; st < n_states; st++) {
     int diag = kScanDiag[scan_pos];
     int base = 0;
     int mid = 0;
@@ -1292,12 +1449,44 @@
   }
 }
 
-void av1_update_lf_ctx_c(const struct tcq_node_t *decision,
-                         struct tcq_lf_ctx_t *lf_ctx) {
-  tcq_lf_ctx_t save[TOTALSTATES];
-  memcpy(save, lf_ctx, sizeof(tcq_lf_ctx_t) * TOTALSTATES);
+void av1_calc_lf_ctx_st8_c(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
+                           struct tcq_coeff_ctx_t *coeff_ctx) {
+  static const int8_t kMaxCtx[16] = { 8, 6, 6, 4, 4, 4, 4, 4,
+                                      4, 4, 4, 4, 4, 4, 4, 4 };
+  static const int8_t kScanDiag[MAX_LF_SCAN] = { 0, 1, 1, 2, 2, 2, 3, 3, 3, 3 };
+  static const int8_t kNbrMask[4][11] = {
+    { 3, 3, 1, 3, 1, 0, 0, 0, 0, 0, 0 },  // diag 0
+    { 0, 3, 3, 0, 1, 3, 1, 0, 0, 0, 0 },  // diag 1
+    { 0, 0, 3, 3, 0, 0, 1, 3, 1, 0, 0 },  // diag 2
+    { 0, 0, 0, 3, 3, 0, 0, 0, 1, 3, 1 },  // diag 3
+  };
+  int n_states = 8;
 
-  for (int st = 0; st < TOTALSTATES; st++) {
+  for (int st = 0; st < n_states; st++) {
+    int diag = kScanDiag[scan_pos];
+    int base = 0;
+    int mid = 0;
+    for (int i = 0; i < 11; i++) {
+      int mask = kNbrMask[diag][i];
+      if (mask) {
+        base += AOMMIN(lf_ctx[st].last[i], 5);
+        if (mask >> 1) {
+          mid += AOMMIN(lf_ctx[st].last[i], MAX_VAL_BR_CTX);
+        }
+      }
+    }
+    int base_ctx = AOMMIN((base + 1) >> 1, kMaxCtx[scan_pos]);
+    int mid_ctx = AOMMIN((mid + 1) >> 1, 6) + ((scan_pos == 0) ? 0 : 7);
+    coeff_ctx->coef[st] = base_ctx + (mid_ctx << 4);
+  }
+}
+
+void av1_update_lf_ctx_c(const struct tcq_node_t *decision, int n_states,
+                         struct tcq_lf_ctx_t *lf_ctx) {
+  tcq_lf_ctx_t save[TCQ_MAX_STATES];
+  memcpy(save, lf_ctx, sizeof(tcq_lf_ctx_t) * n_states);
+
+  for (int st = 0; st < n_states; st++) {
     int absLevel = decision[st].absLevel;
     int prevId = decision[st].prevId;
     int new_eob = prevId < 0;
@@ -1313,25 +1502,33 @@
 }
 
 // Handle trellis Low-freq (LF) region for Luma, TX_CLASS_2D blocks.
-void trellis_loop_lf(int scan_hi, int scan_lo, int plane, TX_SIZE tx_size,
-                     TX_CLASS tx_class, int32_t *tmp_sign, int sharpness,
-                     tcq_levels_t *tcq_lev,
-                     tcq_node_t trellis[MAX_TRELLIS][TOTALSTATES],
-                     tran_low_t *qcoeff, const int64_t rdmult, int log_scale,
-                     const int16_t *scan, const tran_low_t *tcoeff,
-                     const int32_t *dequant, const int32_t *quant,
-                     const qm_val_t *iqmatrix, const uint16_t *block_eob_rate,
-                     const TXB_CTX *const txb_ctx,
-                     const LV_MAP_COEFF_COST *txb_costs) {
+// TCQ 4-state
+static void trellis_loop_lf_st4(const tcq_param_t *p, int scan_hi, int scan_lo,
+                                tcq_levels_t *tcq_lev, tcq_node_t *trellis) {
+  TX_SIZE tx_size = p->tx_size;
+  TX_CLASS tx_class = p->tx_class;
+  int log_scale = p->log_scale;
+  int try_eob = p->sharpness == 0;
+  int64_t rdmult = p->rdmult;
+  const int16_t *scan = p->scan;
+  const int32_t *tmp_sign = p->tmp_sign;
+  const tran_low_t *tcoeff = p->tcoeff;
+  const int32_t *quant = p->quant;
+  const int32_t *dequant = p->dequant;
+  const qm_val_t *iqmatrix = p->iqmatrix;
+  const uint16_t *block_eob_rate = p->block_eob_rate;
+  const TXB_CTX *txb_ctx = p->txb_ctx;
+  const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
   const int bwl = get_txb_bwl(tx_size);
   const int height = get_txb_high(tx_size);
-  assert(plane == 0);
+  assert(p->plane == 0);
   assert(tx_class == TX_CLASS_2D);
-  (void)plane;
-  (void)qcoeff;
 
-  tcq_lf_ctx_t lf_ctx[TOTALSTATES];
-  for (int i = 0; i < TOTALSTATES; i++) {
+  const int n_st = 4;
+  const int n_st_log2 = 2;
+
+  tcq_lf_ctx_t lf_ctx[TCQ_MAX_STATES];
+  for (int i = 0; i < n_st; i++) {
     uint8_t *lev = tcq_levels_cur(tcq_lev, i);
     av1_init_lf_ctx(lev, scan_hi, bwl, &lf_ctx[i]);
   }
@@ -1339,8 +1536,8 @@
   for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
     int blk_pos = scan[scan_pos];
 
-    tcq_node_t *decision = trellis[scan_pos];
-    tcq_node_t *prd = trellis[scan_pos + 1];
+    tcq_node_t *decision = &trellis[scan_pos << n_st_log2];
+    tcq_node_t *prd = &decision[n_st];
 
     prequant_t pqData;
     int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
@@ -1348,50 +1545,134 @@
                   scan_pos);
 
     // init state
-    init_tcq_decision(decision);
+    init_tcq_decision(decision, n_st);
     const int coeff_sign = tcoeff[blk_pos] < 0;
     const int limits = 1;  // Always in LF region.
 
     // calculate contexts
     tcq_coeff_ctx_t coeff_ctx;
     int diag_ctx = get_nz_map_ctx_from_stats_lf(0, blk_pos, bwl, tx_class);
-    av1_calc_lf_ctx(lf_ctx, scan_pos, &coeff_ctx);
     int eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
     int eob_rate = block_eob_rate[scan_pos];
     coeff_ctx.coef_eob = eob_ctx;
+    av1_calc_lf_ctx_st4(lf_ctx, scan_pos, &coeff_ctx);
+
+    // calculate rate distortion
+    tcq_rate_t rd;
+    av1_get_rate_dist_lf_luma_st4(
+        txb_costs, &pqData, &coeff_ctx, blk_pos, diag_ctx, eob_rate,
+        txb_ctx->dc_sign_ctx, tmp_sign, bwl, tx_class, coeff_sign, n_st, &rd);
+
+    av1_decide_states_st4(prd, &rd, &pqData, n_st, limits, try_eob, rdmult,
+                          decision);
+
+    av1_update_lf_ctx(decision, n_st, lf_ctx);
+  }
+}
+
+// TCQ 8-state
+static void trellis_loop_lf_st8(const tcq_param_t *p, int scan_hi, int scan_lo,
+                                tcq_levels_t *tcq_lev, tcq_node_t *trellis) {
+  TX_SIZE tx_size = p->tx_size;
+  TX_CLASS tx_class = p->tx_class;
+  int log_scale = p->log_scale;
+  int try_eob = p->sharpness == 0;
+  int64_t rdmult = p->rdmult;
+  const int16_t *scan = p->scan;
+  const int32_t *tmp_sign = p->tmp_sign;
+  const tran_low_t *tcoeff = p->tcoeff;
+  const int32_t *quant = p->quant;
+  const int32_t *dequant = p->dequant;
+  const qm_val_t *iqmatrix = p->iqmatrix;
+  const uint16_t *block_eob_rate = p->block_eob_rate;
+  const TXB_CTX *txb_ctx = p->txb_ctx;
+  const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
+  const int bwl = get_txb_bwl(tx_size);
+  const int height = get_txb_high(tx_size);
+  assert(p->plane == 0);
+  assert(tx_class == TX_CLASS_2D);
+
+  const int n_st = 8;
+  const int n_st_log2 = 3;
+
+  tcq_lf_ctx_t lf_ctx[TCQ_MAX_STATES];
+  for (int i = 0; i < n_st; i++) {
+    uint8_t *lev = tcq_levels_cur(tcq_lev, i);
+    av1_init_lf_ctx(lev, scan_hi, bwl, &lf_ctx[i]);
+  }
+
+  for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
+    int blk_pos = scan[scan_pos];
+
+    tcq_node_t *decision = &trellis[scan_pos << n_st_log2];
+    tcq_node_t *prd = &decision[n_st];
+
+    prequant_t pqData;
+    int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
+    av1_pre_quant(tcoeff[blk_pos], &pqData, quant, tempdqv, log_scale,
+                  scan_pos);
+
+    // init state
+    init_tcq_decision(decision, n_st);
+    const int coeff_sign = tcoeff[blk_pos] < 0;
+    const int limits = 1;  // Always in LF region.
+
+    // calculate contexts
+    tcq_coeff_ctx_t coeff_ctx;
+    int diag_ctx = get_nz_map_ctx_from_stats_lf(0, blk_pos, bwl, tx_class);
+    int eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
+    int eob_rate = block_eob_rate[scan_pos];
+    coeff_ctx.coef_eob = eob_ctx;
+    av1_calc_lf_ctx_st8(lf_ctx, scan_pos, &coeff_ctx);
 
     // calculate rate distortion
     tcq_rate_t rd;
     av1_get_rate_dist_lf_luma(txb_costs, &pqData, &coeff_ctx, blk_pos, diag_ctx,
                               eob_rate, txb_ctx->dc_sign_ctx, tmp_sign, bwl,
-                              tx_class, coeff_sign, &rd);
+                              tx_class, coeff_sign, n_st, &rd);
 
-    int try_eob = sharpness == 0;
-    av1_decide_states(prd, &rd, &pqData, limits, try_eob, rdmult, decision);
+    av1_decide_states(prd, &rd, &pqData, n_st, limits, try_eob, rdmult,
+                      decision);
 
-    av1_update_lf_ctx(decision, lf_ctx);
+    av1_update_lf_ctx(decision, n_st, lf_ctx);
   }
 }
 
-void trellis_loop(int first_scan_pos, int scan_hi, int scan_lo, int plane,
-                  TX_SIZE tx_size, TX_CLASS tx_class, int32_t *tmp_sign,
-                  int sharpness, tcq_levels_t *tcq_lev,
-                  tcq_node_t trellis[MAX_TRELLIS][TOTALSTATES],
-                  tran_low_t *qcoeff, const int64_t rdmult, int log_scale,
-                  const int16_t *scan, const tran_low_t *tcoeff,
-                  const int32_t *dequant, const int32_t *quant,
-                  const qm_val_t *iqmatrix, const uint16_t *block_eob_rate,
-                  const TXB_CTX *const txb_ctx,
-                  const LV_MAP_COEFF_COST *txb_costs) {
+void trellis_loop(const tcq_param_t *p, int first_scan_pos, int scan_hi,
+                  int scan_lo, tcq_levels_t *tcq_lev, tcq_node_t *trellis) {
+  int n_states = p->n_states;
+  int n_states_log2 = p->n_states_log2;
+  int plane = p->plane;
+  TX_SIZE tx_size = p->tx_size;
+  TX_CLASS tx_class = p->tx_class;
+  int log_scale = p->log_scale;
+  int sharpness = p->sharpness;
+  int try_eob = sharpness == 0;
+  int64_t rdmult = p->rdmult;
+  const int16_t *scan = p->scan;
+  const int32_t *tmp_sign = p->tmp_sign;
+  const tran_low_t *tcoeff = p->tcoeff;
+  const int32_t *quant = p->quant;
+  const int32_t *dequant = p->dequant;
+  const qm_val_t *iqmatrix = p->iqmatrix;
+  const uint16_t *block_eob_rate = p->block_eob_rate;
+  const TXB_CTX *txb_ctx = p->txb_ctx;
+  const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
   const int bwl = get_txb_bwl(tx_size);
   const int height = get_txb_high(tx_size);
-  (void)qcoeff;
+  DecideStateFnc f_decide_states =
+      n_states == 4 ? av1_decide_states_st4 : av1_decide_states;
+  GetDefLumaRateDistFnc f_get_rate_dist_def_luma =
+      n_states == 4 ? av1_get_rate_dist_def_luma_st4
+                    : av1_get_rate_dist_def_luma;
+  GetLfLumaRateDistFnc f_get_rate_dist_lf_luma =
+      n_states == 4 ? av1_get_rate_dist_lf_luma_st4 : av1_get_rate_dist_lf_luma;
 
   for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
     tcq_levels_swap(tcq_lev);
-    uint8_t *levels[TOTALSTATES];
-    uint8_t *prev_levels[TOTALSTATES];
-    for (int i = 0; i < TOTALSTATES; i++) {
+    uint8_t *levels[TCQ_MAX_STATES];
+    uint8_t *prev_levels[TCQ_MAX_STATES];
+    for (int i = 0; i < n_states; i++) {
       prev_levels[i] = tcq_levels_prev(tcq_lev, i);
       levels[i] = tcq_levels_cur(tcq_lev, i);
     }
@@ -1401,8 +1682,8 @@
     int col = blk_pos - (row << bwl);
     int limits = get_lf_limits(row, col, tx_class, plane);
 
-    tcq_node_t *decision = trellis[scan_pos];
-    tcq_node_t *prd = trellis[scan_pos + 1];
+    tcq_node_t *decision = &trellis[scan_pos << n_states_log2];
+    tcq_node_t *prd = &decision[n_states];
 
     prequant_t pqData;
     int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
@@ -1410,77 +1691,78 @@
                   scan_pos);
 
     // init state
-    init_tcq_decision(decision);
+    init_tcq_decision(decision, n_states);
     const int coeff_sign = tcoeff[blk_pos] < 0;
 
     // calculate contexts
-    int diag_ctx =
-        (limits && plane == 0)
-            ? get_nz_map_ctx_from_stats_lf(0, blk_pos, bwl, tx_class)
-        : plane == 0 ? get_nz_map_ctx_from_stats(0, blk_pos, bwl, tx_class, 0)
-        : limits
-            ? get_nz_map_ctx_from_stats_lf_chroma(0, tx_class, plane)
-            : get_nz_map_ctx_from_stats_chroma(0, blk_pos, tx_class, plane);
 
     tcq_coeff_ctx_t coeff_ctx;
-    if (limits) {
-      for (int i = 0; i < TOTALSTATES; i++) {
-        int base_ctx = plane
-                           ? get_lower_levels_lf_ctx_chroma(
-                                 prev_levels[i], blk_pos, bwl, tx_class, plane)
-                           : get_lower_levels_lf_ctx(prev_levels[i], blk_pos,
-                                                     bwl, tx_class);
-        int br_ctx =
-            plane ? get_br_lf_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class)
-                  : get_br_lf_ctx(prev_levels[i], blk_pos, bwl, tx_class);
-        coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
-      }
-    } else {
-      for (int i = 0; i < TOTALSTATES; i++) {
-        int base_ctx =
-            plane ? get_lower_levels_ctx_chroma(prev_levels[i], blk_pos, bwl,
-                                                tx_class, plane)
-                  : get_lower_levels_ctx(prev_levels[i], blk_pos, bwl, tx_class
-#if CONFIG_CHROMA_TX_COEFF_CODING
-                                         ,
-                                         plane
-#endif
-                    );
-        int br_ctx =
-            plane ? get_br_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class)
-                  : get_br_ctx(prev_levels[i], blk_pos, bwl, tx_class);
-        coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
-      }
-    }
     int eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
     int eob_rate = block_eob_rate[scan_pos];
     coeff_ctx.coef_eob = eob_ctx;
 
-    // calculate rate distortion
     tcq_rate_t rd;
-    if (limits && plane == 0) {
-      av1_get_rate_dist_lf_luma(txb_costs, &pqData, &coeff_ctx, blk_pos,
+
+    // Calculate contexts and rate distortion
+    if (limits) {
+      if (plane == 0) {
+        int diag_ctx = get_nz_map_ctx_from_stats_lf(0, blk_pos, bwl, tx_class);
+        for (int i = 0; i < n_states; i++) {
+          int base_ctx =
+              get_lower_levels_lf_ctx(prev_levels[i], blk_pos, bwl, tx_class);
+          int br_ctx = get_br_lf_ctx(prev_levels[i], blk_pos, bwl, tx_class);
+          coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
+        }
+        f_get_rate_dist_lf_luma(txb_costs, &pqData, &coeff_ctx, blk_pos,
                                 diag_ctx, eob_rate, txb_ctx->dc_sign_ctx,
-                                tmp_sign, bwl, tx_class, coeff_sign, &rd);
-    } else if (limits) {
-      av1_get_rate_dist_lf_chroma(txb_costs, &pqData, &coeff_ctx, blk_pos,
-                                  diag_ctx, eob_rate, txb_ctx->dc_sign_ctx,
-                                  tmp_sign, bwl, tx_class, plane, coeff_sign,
-                                  &rd);
-    } else if (plane == 0) {
-      av1_get_rate_dist_def_luma(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
-                                 tx_class, diag_ctx, eob_rate, &rd);
+                                tmp_sign, bwl, tx_class, coeff_sign, n_states,
+                                &rd);
+      } else {
+        int diag_ctx = get_nz_map_ctx_from_stats_lf_chroma(0, tx_class, plane);
+        for (int i = 0; i < n_states; i++) {
+          int base_ctx = get_lower_levels_lf_ctx_chroma(prev_levels[i], blk_pos,
+                                                        bwl, tx_class, plane);
+          int br_ctx =
+              get_br_lf_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class);
+          coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
+        }
+        av1_get_rate_dist_lf_chroma(txb_costs, &pqData, &coeff_ctx, blk_pos,
+                                    diag_ctx, eob_rate, txb_ctx->dc_sign_ctx,
+                                    tmp_sign, bwl, tx_class, plane, coeff_sign,
+                                    n_states, &rd);
+      }
     } else {
-      av1_get_rate_dist_def_chroma(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
-                                   tx_class, diag_ctx, eob_rate, plane,
-                                   tmp_sign[blk_pos], coeff_sign, &rd);
+      if (plane == 0) {
+        int diag_ctx = get_nz_map_ctx_from_stats(0, blk_pos, bwl, tx_class, 0);
+        for (int i = 0; i < n_states; i++) {
+          int base_ctx = get_lower_levels_ctx(prev_levels[i], blk_pos, bwl,
+                                              tx_class, plane);
+          int br_ctx = get_br_ctx(prev_levels[i], blk_pos, bwl, tx_class);
+          coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
+        }
+        f_get_rate_dist_def_luma(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
+                                 tx_class, diag_ctx, eob_rate, n_states, &rd);
+      } else {
+        int diag_ctx =
+            get_nz_map_ctx_from_stats_chroma(0, blk_pos, tx_class, plane);
+        for (int i = 0; i < n_states; i++) {
+          int base_ctx = get_lower_levels_ctx_chroma(prev_levels[i], blk_pos,
+                                                     bwl, tx_class, plane);
+          int br_ctx =
+              get_br_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class);
+          coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
+        }
+        av1_get_rate_dist_def_chroma(
+            txb_costs, &pqData, &coeff_ctx, blk_pos, bwl, tx_class, diag_ctx,
+            eob_rate, plane, tmp_sign[blk_pos], coeff_sign, n_states, &rd);
+      }
     }
 
-    int try_eob = sharpness == 0;
-    av1_decide_states(prd, &rd, &pqData, limits, try_eob, rdmult, decision);
+    f_decide_states(prd, &rd, &pqData, n_states, limits, try_eob, rdmult,
+                    decision);
 
     // copy corresponding context from previous level buffer
-    for (int state = 0; state < TOTALSTATES && scan_pos != first_scan_pos;
+    for (int state = 0; state < n_states && scan_pos != first_scan_pos;
          state++) {
       int prevId = decision[state].prevId;
       if (prevId >= 0)
@@ -1489,7 +1771,7 @@
     }
 
     // update levels_buf
-    for (int state = 0; state < TOTALSTATES && scan_pos != 0; state++) {
+    for (int state = 0; state < n_states && scan_pos != 0; state++) {
       set_levels_buf(decision[state].prevId, decision[state].absLevel,
                      levels[state], scan, first_scan_pos, scan_pos, bwl,
                      sharpness);
@@ -1539,16 +1821,18 @@
   }
 }
 
-int av1_find_best_path_c(const struct tcq_node_t *trellis, const int16_t *scan,
-                         const int32_t *dequant, const qm_val_t *iqmatrix,
-                         const tran_low_t *tcoeff, int first_scan_pos,
-                         int log_scale, tran_low_t *qcoeff, tran_low_t *dqcoeff,
-                         int *min_rate, int64_t *min_cost) {
+int av1_find_best_path_c(const struct tcq_node_t *trellis, int n_states_log2,
+                         const int16_t *scan, const int32_t *dequant,
+                         const qm_val_t *iqmatrix, const tran_low_t *tcoeff,
+                         int first_scan_pos, int log_scale, tran_low_t *qcoeff,
+                         tran_low_t *dqcoeff, int *min_rate,
+                         int64_t *min_cost) {
   // Select best trellis state.
+  int n_states = 1 << n_states_log2;
   int64_t min_path_cost = INT64_MAX;
   int trel_min_rate = INT32_MAX;
   int prev_id = -2;
-  for (int state = 0; state < TOTALSTATES; state++) {
+  for (int state = 0; state < n_states; state++) {
     const tcq_node_t *decision = &trellis[state];
     if (decision->rdCost < min_path_cost) {
       prev_id = state;
@@ -1563,7 +1847,8 @@
     int dqv = dequant[0];
     int dqv_ac = dequant[1];
     for (; prev_id >= 0; scan_pos++) {
-      const tcq_node_t *decision = &trellis[scan_pos * TOTALSTATES + prev_id];
+      const tcq_node_t *decision =
+          &trellis[(scan_pos << n_states_log2) + prev_id];
       prev_id = decision->prevId;
       int abs_level = decision->absLevel;
       int blk_pos = scan[scan_pos];
@@ -1579,7 +1864,8 @@
     }
   } else {
     for (; prev_id >= 0; scan_pos++) {
-      const tcq_node_t *decision = &trellis[scan_pos * TOTALSTATES + prev_id];
+      const tcq_node_t *decision =
+          &trellis[(scan_pos << n_states_log2) + prev_id];
       prev_id = decision->prevId;
       int abs_level = decision->absLevel;
       int blk_pos = scan[scan_pos];
@@ -1676,8 +1962,9 @@
 
   // getting context from previous level buf, updating levels on current level
   // buf. initialization all value by 0, since we update every position.
+  int n_states_log2 = cm->features.tcq_mode == TCQ_8ST ? 3 : 2;
   int bufsize = (width + 4) * (height + 4) + TX_PAD_END;
-  int mem_tcq_sz = sizeof(uint8_t) * bufsize * 2 * TOTALSTATES;
+  int mem_tcq_sz = sizeof(uint8_t) * bufsize * (2 << n_states_log2);
   uint8_t *mem_tcq = (uint8_t *)malloc(mem_tcq_sz);
   if (!mem_tcq) {
     exit(1);
@@ -1691,10 +1978,10 @@
   int si = eob - 1;
   // populate trellis
   assert(si < MAX_TRELLIS);
-  tcq_node_t trellis[MAX_TRELLIS][TOTALSTATES];
+  tcq_node_t trellis[MAX_TRELLIS * TCQ_MAX_STATES];
 
   // Ping-pong buffers for diagonal contexts.
-  tcq_ctx_t tcq_ctx[2 * TOTALSTATES];
+  tcq_ctx_t tcq_ctx[2 * TCQ_MAX_STATES];
 
   // Precalc block eob rate.
   uint16_t block_eob_rate[MAX_TRELLIS];
@@ -1792,39 +2079,64 @@
   }
 #endif  // CONFIG_TXFMBLK_LOGS || CONFIG_COEFF_LOGS
 
+  // Collect TCQ related parameters.
+  tcq_param_t param;
+  int log_scale = av1_get_tx_scale(tx_size) + 1;
+  param.n_states_log2 = n_states_log2;
+  param.n_states = 1 << param.n_states_log2;
+  param.plane = plane;
+  param.tx_size = tx_size;
+  param.tx_class = tx_class;
+  param.sharpness = sharpness;
+  param.rdmult = rdmult;
+  param.log_scale = log_scale;
+  param.scan = scan;
+  param.tmp_sign = xd->tmp_sign;
+  param.qcoeff = qcoeff;
+  param.tcoeff = tcoeff;
+  param.quant = quant;
+  param.dequant = dequant;
+  param.iqmatrix = iqmatrix;
+  param.block_eob_rate = block_eob_rate;
+  param.txb_ctx = txb_ctx;
+  param.txb_costs = txb_costs;
+
   // Start of TCQ
   int first_scan_pos = si;
-  int log_scale = av1_get_tx_scale(tx_size) + 1;
-  trellis_first_pos(first_scan_pos, plane, tx_size, tx_class, xd->tmp_sign,
-                    sharpness, &tcq_lev, trellis, qcoeff, rdmult, log_scale,
-                    scan, tcoeff, dequant, quant, iqmatrix, block_eob_rate,
-                    txb_ctx, txb_costs);
+  trellis_first_pos(&param, first_scan_pos, &tcq_lev, trellis);
+
   int scan_hi = first_scan_pos - 1;
 
   if (scan_hi >= 0) {
     if (plane == 0 && tx_class == TX_CLASS_2D) {
       const int scan_lf_start = 9;
-      while (scan_hi > scan_lf_start) {
-        int blk_pos = scan[scan_hi];
-        int row = blk_pos >> bwl;
-        int col = blk_pos - (row << bwl);
-        int inc = AOMMIN(height - 1 - row, col);
-        int scan_lo = AOMMAX(scan_lf_start + 1, scan_hi - inc);
-        trellis_loop_diagonal(scan_hi, scan_lo, 0, tx_size, TX_CLASS_2D, 0,
-                              sharpness, &tcq_lev, tcq_ctx, trellis, qcoeff,
-                              rdmult, log_scale, scan, tcoeff, dequant, quant,
-                              iqmatrix, block_eob_rate, txb_ctx, txb_costs);
-        scan_hi = scan_lo - 1;
+      if (param.n_states == 4) {
+        while (scan_hi > scan_lf_start) {
+          int blk_pos = scan[scan_hi];
+          int row = blk_pos >> bwl;
+          int col = blk_pos - (row << bwl);
+          int inc = AOMMIN(height - 1 - row, col);
+          int scan_lo = AOMMAX(scan_lf_start + 1, scan_hi - inc);
+          trellis_loop_diagonal_st4(&param, scan_hi, scan_lo, &tcq_lev, tcq_ctx,
+                                    trellis);
+          scan_hi = scan_lo - 1;
+        }
+        trellis_loop_lf_st4(&param, scan_hi, 0, &tcq_lev, trellis);
+      } else {  // n_states == 8
+        while (scan_hi > scan_lf_start) {
+          int blk_pos = scan[scan_hi];
+          int row = blk_pos >> bwl;
+          int col = blk_pos - (row << bwl);
+          int inc = AOMMIN(height - 1 - row, col);
+          int scan_lo = AOMMAX(scan_lf_start + 1, scan_hi - inc);
+          trellis_loop_diagonal_st8(&param, scan_hi, scan_lo, &tcq_lev, tcq_ctx,
+                                    trellis);
+          scan_hi = scan_lo - 1;
+        }
+        trellis_loop_lf_st8(&param, scan_hi, 0, &tcq_lev, trellis);
       }
-      trellis_loop_lf(scan_hi, 0, plane, tx_size, tx_class, xd->tmp_sign,
-                      sharpness, &tcq_lev, trellis, qcoeff, rdmult, log_scale,
-                      scan, tcoeff, dequant, quant, iqmatrix, block_eob_rate,
-                      txb_ctx, txb_costs);
     } else {
-      trellis_loop(first_scan_pos, scan_hi, 0, plane, tx_size, tx_class,
-                   xd->tmp_sign, sharpness, &tcq_lev, trellis, qcoeff, rdmult,
-                   log_scale, scan, tcoeff, dequant, quant, iqmatrix,
-                   block_eob_rate, txb_ctx, txb_costs);
+      trellis_loop(&param, first_scan_pos, scan_hi, 0, &tcq_lev, trellis);
     }
   }
 
@@ -1833,8 +2145,8 @@
   // find best path
   int min_rate = INT32_MAX;
   int64_t min_path_cost = INT64_MAX;
-  eob = av1_find_best_path(&trellis[0][0], scan, dequant, iqmatrix, tcoeff,
-                           first_scan_pos, log_scale, qcoeff, dqcoeff,
+  eob = av1_find_best_path(trellis, n_states_log2, scan, dequant, iqmatrix,
+                           tcoeff, first_scan_pos, log_scale, qcoeff, dqcoeff,
                            &min_rate, &min_path_cost);
 
 #if CONFIG_CONTEXT_DERIVATION
diff --git a/av1/encoder/trellis_quant.h b/av1/encoder/trellis_quant.h
index e315689..30e5673 100644
--- a/av1/encoder/trellis_quant.h
+++ b/av1/encoder/trellis_quant.h
@@ -52,17 +52,38 @@
 } prequant_t;
 
 typedef struct tcq_rate_t {
-  int32_t rate[2 * TOTALSTATES];
-  int32_t rate_zero[TOTALSTATES];
+  int32_t rate[2 * TCQ_MAX_STATES];
+  int32_t rate_zero[TCQ_MAX_STATES];
   int32_t rate_eob[2];
 } tcq_rate_t;
 
 typedef struct tcq_coeff_ctx_t {
-  uint8_t coef[TOTALSTATES];
+  uint8_t coef[TCQ_MAX_STATES];
   uint8_t coef_eob;
   uint8_t pad[3];
 } tcq_coeff_ctx_t;
 
+typedef struct tcq_param_t {
+  int n_states;
+  int n_states_log2;
+  int plane;
+  TX_SIZE tx_size;
+  TX_CLASS tx_class;
+  int sharpness;
+  int64_t rdmult;
+  int log_scale;
+  const int16_t *scan;
+  const int32_t *tmp_sign;
+  const tran_low_t *qcoeff;
+  const tran_low_t *tcoeff;
+  const int32_t *quant;
+  const int32_t *dequant;
+  const qm_val_t *iqmatrix;
+  const uint16_t *block_eob_rate;
+  const TXB_CTX *txb_ctx;
+  const LV_MAP_COEFF_COST *txb_costs;
+} tcq_param_t;
+
 static AOM_FORCE_INLINE int get_low_range(int abs_qc, int lf) {
   int base_levels = lf ? 6 : 4;
   int parity = abs_qc & 1;
diff --git a/av1/encoder/x86/trellis_quant_avx2.c b/av1/encoder/x86/trellis_quant_avx2.c
index d1059ea..6563257 100644
--- a/av1/encoder/x86/trellis_quant_avx2.c
+++ b/av1/encoder/x86/trellis_quant_avx2.c
@@ -21,19 +21,36 @@
 #include "aom_dsp/x86/synonyms.h"
 #include "aom_dsp/x86/synonyms_avx2.h"
 
+// av1_decide_states_*() constants.
+static const int32_t kShuffle[8] = { 0, 2, 1, 3, 5, 7, 4, 6 };
+static const int32_t kPrevId[TCQ_MAX_STATES / 4][8] = {
+  { 0, 0 << 24, 0, 1 << 24, 0, 2 << 24, 0, 3 << 24 },
+  { 0, 4 << 24, 0, 5 << 24, 0, 6 << 24, 0, 7 << 24 },
+};
+
+// av1_calc_lf_ctx_*() constants.
+// Neighbor mask for calculating context sum (base/mid).
+#define M MAX_VAL_BR_CTX
+static const int8_t kNbrMask[4][32] = {
+  { 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  // diag 0
+    M, M, 0, M, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+  { 0, 5, 5, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0,  // diag 1
+    0, M, M, 0, 0, M, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+  { 0, 0, 5, 5, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0,  // diag 2
+    0, 0, M, M, 0, 0, 0, M, 0, 0, 0, 0, 0, 0, 0, 0 },
+  { 0, 0, 0, 5, 5, 0, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0,  // diag 3
+    0, 0, 0, M, M, 0, 0, 0, 0, M, 0, 0, 0, 0, 0, 0 },
+};
+static const int8_t kMaxCtx[16] = { 8, 6, 6, 4, 4, 4, 4, 4,
+                                    4, 4, 4, 4, 4, 4, 4, 4 };
+static const int8_t kScanDiag[MAX_LF_SCAN] = { 0, 1, 1, 2, 2, 2, 3, 3, 3, 3 };
+
 void av1_decide_states_avx2(const struct tcq_node_t *prev,
                             const struct tcq_rate_t *rd,
-                            const struct prequant_t *pq, int limits,
-                            int try_eob, int64_t rdmult,
+                            const struct prequant_t *pq, int n_states,
+                            int limits, int try_eob, int64_t rdmult,
                             struct tcq_node_t *decision) {
   (void)limits;
-  static const int32_t kShuffle[8] = { 0, 2, 1, 3, 5, 7, 4, 6 };
-  static const int32_t kPrevId[TOTALSTATES / 4][8] = {
-    { 0, 0 << 24, 0, 1 << 24, 0, 2 << 24, 0, 3 << 24 },
-#if MORESTATES
-    { 0, 4 << 24, 0, 5 << 24, 0, 6 << 24, 0, 7 << 24 },
-#endif
-  };
   assert((rdmult >> 32) == 0);
   assert(sizeof(tcq_node_t) == 16);
 
@@ -50,7 +67,10 @@
   __m256i abslev0033 = _mm256_unpacklo_epi32(c_zero, abslev00223311);
   __m256i abslev2211 = _mm256_unpackhi_epi32(c_zero, abslev00223311);
 
-  for (int i = 0; i < TOTALSTATES / 4; i++) {
+  __m256i *out_a = (__m256i *)&decision[0];
+  __m256i *out_b = (__m256i *)&decision[n_states >> 1];
+
+  for (int i = 0; i < n_states >> 2; i++) {
     // Load distortion.
     __m256i dist = _mm256_lddqu_si256((__m256i *)&pq->deltaDist[0]);
     dist = _mm256_slli_epi64(dist, RDDIV_BITS);
@@ -109,8 +129,8 @@
     __m256i rdcost3164 = _mm256_shuffle_epi32(rdcost1346, 0x4E);
     __m256i rate3164 = _mm256_shuffle_epi32(rate1346, 0x4E);
     __m256i use_odd = _mm256_cmpgt_epi64(rdcost0257, rdcost3164);
-    __m256i prev_id = _mm256_lddqu_si256((__m256i *)kPrevId[i]);
     __m256i use_odd_1 = _mm256_slli_epi64(_mm256_srli_epi64(use_odd, 63), 56);
+    __m256i prev_id = _mm256_lddqu_si256((__m256i *)kPrevId[i]);
     prev_id = _mm256_xor_si256(prev_id, use_odd_1);
     __m256i rdcost_best = _mm256_blendv_epi8(rdcost0257, rdcost3164, use_odd);
     __m256i rate_best = _mm256_blendv_epi8(rate0257, rate3164, use_odd);
@@ -141,16 +161,134 @@
     info_best = _mm256_or_si256(info_best, prev_id);
     __m256i info01 = _mm256_unpacklo_epi64(rdcost_best, info_best);
     __m256i info23 = _mm256_unpackhi_epi64(rdcost_best, info_best);
-#if MORESTATES
-    _mm256_storeu_si256((__m256i *)&decision[6 * i], info01);
-    _mm256_storeu_si256((__m256i *)&decision[4 - (2 * i)], info23);
-#else
-    _mm256_storeu_si256((__m256i *)&decision[0], info01);
-    _mm256_storeu_si256((__m256i *)&decision[2], info23);
-#endif
+    _mm256_storeu_si256(out_a, info01);
+    _mm256_storeu_si256(out_b, info23);
+    out_a = (__m256i *)&decision[6];
+    out_b = (__m256i *)&decision[2];
   }
 }
 
+void av1_decide_states_st4_avx2(const struct tcq_node_t *prev,
+                                const struct tcq_rate_t *rd,
+                                const struct prequant_t *pq, int n_states,
+                                int limits, int try_eob, int64_t rdmult,
+                                struct tcq_node_t *decision) {
+  (void)limits;
+  (void)n_states;
+  assert(n_states == 4);
+  assert((rdmult >> 32) == 0);
+  assert(sizeof(tcq_node_t) == 16);
+
+  int i = 0;
+
+  __m256i c_rdmult = _mm256_set1_epi64x(rdmult);
+  __m256i c_round = _mm256_set1_epi64x(1 << (AV1_PROB_COST_SHIFT - 1));
+  __m256i c_zero = _mm256_setzero_si256();
+
+  // Gather absolute coeff level for 4 possible quant options.
+  __m128i abslev0123 = _mm_lddqu_si128((__m128i *)pq->absLevel);
+  __m256i abslev0231 =
+      _mm256_castsi128_si256(_mm_shuffle_epi32(abslev0123, 0x78));
+  __m256i abslev02023131 = _mm256_permute4x64_epi64(abslev0231, 0x50);
+  __m256i abslev00223311 = _mm256_shuffle_epi32(abslev02023131, 0x50);
+  __m256i abslev0033 = _mm256_unpacklo_epi32(c_zero, abslev00223311);
+  __m256i abslev2211 = _mm256_unpackhi_epi32(c_zero, abslev00223311);
+
+  // Load distortion.
+  __m256i dist = _mm256_lddqu_si256((__m256i *)&pq->deltaDist[0]);
+  dist = _mm256_slli_epi64(dist, RDDIV_BITS);
+  __m256i dist0033 = _mm256_permute4x64_epi64(dist, 0xF0);
+  __m256i dist2211 = _mm256_permute4x64_epi64(dist, 0x5A);
+
+  // Calc rate-distortion costs for each pair of even/odd quant.
+  // Separate candidates into even and odd quant decisions
+  // Even indexes: { 0, 2, 5, 7 }. Odd: { 1, 3, 4, 6 }.
+  __m256i rates = _mm256_lddqu_si256((__m256i *)&rd->rate[8 * i]);
+  __m256i permute_mask = _mm256_lddqu_si256((__m256i *)kShuffle);
+  __m256i rate02135746 = _mm256_permutevar8x32_epi32(rates, permute_mask);
+  __m256i rate0257 = _mm256_unpacklo_epi32(rate02135746, c_zero);
+  __m256i rate1346 = _mm256_unpackhi_epi32(rate02135746, c_zero);
+  __m256i rdcost0257 = _mm256_mul_epu32(c_rdmult, rate0257);
+  __m256i rdcost1346 = _mm256_mul_epu32(c_rdmult, rate1346);
+  rdcost0257 = _mm256_add_epi64(rdcost0257, c_round);
+  rdcost1346 = _mm256_add_epi64(rdcost1346, c_round);
+  rdcost0257 = _mm256_srli_epi64(rdcost0257, AV1_PROB_COST_SHIFT);
+  rdcost1346 = _mm256_srli_epi64(rdcost1346, AV1_PROB_COST_SHIFT);
+  rdcost0257 = _mm256_add_epi64(rdcost0257, dist0033);
+  rdcost1346 = _mm256_add_epi64(rdcost1346, dist2211);
+
+  // Calc rd-cost for zero quant.
+  __m256i ratezero =
+      _mm256_castsi128_si256(_mm_lddqu_si128((__m128i *)&rd->rate_zero[4 * i]));
+  ratezero = _mm256_permute4x64_epi64(ratezero, 0x50);
+  ratezero = _mm256_unpacklo_epi32(ratezero, c_zero);
+  __m256i rdcostzero = _mm256_mul_epu32(c_rdmult, ratezero);
+  rdcostzero = _mm256_add_epi64(rdcostzero, c_round);
+  rdcostzero = _mm256_srli_epi64(rdcostzero, AV1_PROB_COST_SHIFT);
+
+  // Add previous state rdCost to rdcostzero
+  __m256i state01 = _mm256_lddqu_si256((__m256i *)&prev[4 * i]);
+  __m256i state23 = _mm256_lddqu_si256((__m256i *)&prev[4 * i + 2]);
+  __m256i state02 = _mm256_permute2x128_si256(state01, state23, 0x20);
+  __m256i state13 = _mm256_permute2x128_si256(state01, state23, 0x31);
+  __m256i prevrd0123 = _mm256_unpacklo_epi64(state02, state13);
+  __m256i prevrate0123 = _mm256_unpackhi_epi64(state02, state13);
+  prevrate0123 = _mm256_slli_epi64(prevrate0123, 32);
+  prevrate0123 = _mm256_srli_epi64(prevrate0123, 32);
+
+  // Compare rd costs (Zero vs Even).
+  __m256i use_zero = _mm256_cmpgt_epi64(rdcost0257, rdcostzero);
+  rdcost0257 = _mm256_blendv_epi8(rdcost0257, rdcostzero, use_zero);
+  rate0257 = _mm256_blendv_epi8(rate0257, ratezero, use_zero);
+  __m256i abslev_even = _mm256_andnot_si256(use_zero, abslev0033);
+
+  // Add previous state rdCost to current rdcost
+  rdcost0257 = _mm256_add_epi64(rdcost0257, prevrd0123);
+  rdcost1346 = _mm256_add_epi64(rdcost1346, prevrd0123);
+  rate0257 = _mm256_add_epi64(rate0257, prevrate0123);
+  rate1346 = _mm256_add_epi64(rate1346, prevrate0123);
+
+  // Compare rd costs (Even vs Odd).
+  __m256i rdcost3164 = _mm256_shuffle_epi32(rdcost1346, 0x4E);
+  __m256i rate3164 = _mm256_shuffle_epi32(rate1346, 0x4E);
+  __m256i use_odd = _mm256_cmpgt_epi64(rdcost0257, rdcost3164);
+  __m256i use_odd_1 = _mm256_slli_epi64(_mm256_srli_epi64(use_odd, 63), 56);
+  __m256i prev_id = _mm256_lddqu_si256((__m256i *)kPrevId[i]);
+  prev_id = _mm256_xor_si256(prev_id, use_odd_1);
+  __m256i rdcost_best = _mm256_blendv_epi8(rdcost0257, rdcost3164, use_odd);
+  __m256i rate_best = _mm256_blendv_epi8(rate0257, rate3164, use_odd);
+  __m256i abslev_best = _mm256_blendv_epi8(abslev_even, abslev2211, use_odd);
+
+  // Compare rd costs (best vs new eob).
+  __m256i rate_eob = _mm256_castsi128_si256(_mm_loadu_si64(rd->rate_eob));
+  rate_eob = _mm256_unpacklo_epi32(rate_eob, c_zero);
+  __m256i rdcost_eob = _mm256_mul_epu32(c_rdmult, rate_eob);
+  rdcost_eob = _mm256_add_epi64(rdcost_eob, c_round);
+  rdcost_eob = _mm256_srli_epi64(rdcost_eob, AV1_PROB_COST_SHIFT);
+  __m256i dist_eob = _mm256_unpacklo_epi64(dist0033, dist2211);
+  rdcost_eob = _mm256_add_epi64(rdcost_eob, dist_eob);
+  __m128i mask_eob0 = _mm_set1_epi64x((int64_t)-try_eob);
+  __m256i mask_eob = _mm256_inserti128_si256(c_zero, mask_eob0, 0);
+  __m256i use_eob = _mm256_cmpgt_epi64(rdcost_best, rdcost_eob);
+  use_eob = _mm256_and_si256(use_eob, mask_eob);
+  __m256i use_eob_1 = _mm256_slli_epi64(use_eob, 56);
+  prev_id = _mm256_or_si256(prev_id, use_eob_1);
+  rdcost_best = _mm256_blendv_epi8(rdcost_best, rdcost_eob, use_eob);
+  rate_best = _mm256_blendv_epi8(rate_best, rate_eob, use_eob);
+  __m256i abslev_eob = _mm256_unpacklo_epi64(abslev0033, abslev2211);
+  abslev_best = _mm256_blendv_epi8(abslev_best, abslev_eob, use_eob);
+
+  // Pack and store state info.
+  __m256i info_best = _mm256_or_si256(rate_best, abslev_best);
+  info_best = _mm256_or_si256(info_best, prev_id);
+  __m256i info01 = _mm256_unpacklo_epi64(rdcost_best, info_best);
+  __m256i info23 = _mm256_unpackhi_epi64(rdcost_best, info_best);
+  __m256i *out_a = (__m256i *)&decision[0];
+  __m256i *out_b = (__m256i *)&decision[2];
+  _mm256_storeu_si256(out_a, info01);
+  _mm256_storeu_si256(out_b, info23);
+}
+
 void av1_pre_quant_avx2(tran_low_t tqc, struct prequant_t *pqData,
                         const int32_t *quant_ptr, int dqv, int log_scale,
                         int scan_pos) {
@@ -197,10 +335,10 @@
   _mm256_storeu_si256((__m256i *)pqData->deltaDist, dist);
 }
 
-void av1_update_states_avx2(tcq_node_t *decision, int scan_idx,
+void av1_update_states_avx2(tcq_node_t *decision, int scan_idx, int n_states,
                             const struct tcq_ctx_t *cur_ctx,
                             struct tcq_ctx_t *nxt_ctx) {
-  for (int i = 0; i < TOTALSTATES; i++) {
+  for (int i = 0; i < n_states; i++) {
     int prevId = decision[i].prevId;
     int absLevel = decision[i].absLevel;
     if (prevId >= 0) {
@@ -392,7 +530,7 @@
                                      const struct prequant_t *pq,
                                      const tcq_coeff_ctx_t *coeff_ctx,
                                      int blk_pos, int bwl, TX_CLASS tx_class,
-                                     int diag_ctx, int eob_rate,
+                                     int diag_ctx, int eob_rate, int n_states,
                                      struct tcq_rate_t *rd) {
   (void)bwl;
   const int32_t(*cost_zero)[SIG_COEF_CONTEXTS] = txb_costs->base_cost_zero;
@@ -419,11 +557,9 @@
   __m256i ratez_0123 = _mm256_unpacklo_epi64(ratez_dq0, ratez_dq1);
   _mm_storeu_si128((__m128i *)&rd->rate_zero[0],
                    _mm256_castsi256_si128(ratez_0123));
-#if MORESTATES
   __m256i ratez_4567 = _mm256_unpackhi_epi64(ratez_dq0, ratez_dq1);
   _mm_storeu_si128((__m128i *)&rd->rate_zero[4],
                    _mm256_castsi256_si128(ratez_4567));
-#endif
 
   // Calc coeff_base rate.
   int idx = AOMMIN(pq->qIdx - 1, 4);
@@ -432,7 +568,7 @@
   __m256i base_ctx = _mm256_slli_epi16(ctx16, 12);
   base_ctx = _mm256_srli_epi16(base_ctx, 12);
   base_ctx = _mm256_add_epi16(base_ctx, diag);
-  for (int i = 0; i < TOTALSTATES / 4; i++) {
+  for (int i = 0; i < (n_states >> 2); i++) {
     int ctx0 = _mm256_extract_epi16(base_ctx, 0);
     int ctx1 = _mm256_extract_epi16(base_ctx, 1);
     int ctx2 = _mm256_extract_epi16(base_ctx, 2);
@@ -460,7 +596,7 @@
 
   // Calc coeff mid and high range cost.
   if (idx > 0) {
-    for (int i = 0; i < TOTALSTATES; i++) {
+    for (int i = 0; i < n_states; i++) {
       int a0 = i & 2 ? 1 : 0;
       int a1 = a0 + 2;
       int mid_cost0 = get_mid_cost_def(absLevel[a0], coeff_ctx->coef[i],
@@ -479,23 +615,96 @@
   }
 }
 
-void av1_calc_lf_ctx_avx2(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
-                          struct tcq_coeff_ctx_t *coeff_ctx) {
-#define M MAX_VAL_BR_CTX
-  // Neighbor mask for calculating context sum (base/mid).
-  static const int8_t kNbrMask[4][32] = {
-    { 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  // diag 0
-      M, M, 0, M, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { 0, 5, 5, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0,  // diag 1
-      0, M, M, 0, 0, M, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { 0, 0, 5, 5, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0,  // diag 2
-      0, 0, M, M, 0, 0, 0, M, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { 0, 0, 0, 5, 5, 0, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0,  // diag 3
-      0, 0, 0, M, M, 0, 0, 0, 0, M, 0, 0, 0, 0, 0, 0 },
-  };
-  static const int8_t kMaxCtx[16] = { 8, 6, 6, 4, 4, 4, 4, 4,
-                                      4, 4, 4, 4, 4, 4, 4, 4 };
-  static const int8_t kScanDiag[MAX_LF_SCAN] = { 0, 1, 1, 2, 2, 2, 3, 3, 3, 3 };
+void av1_get_rate_dist_def_luma_st4_avx2(
+    const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
+    const tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class,
+    int diag_ctx, int eob_rate, int n_states, struct tcq_rate_t *rd) {
+  (void)bwl;
+  assert(n_states == 4);
+  n_states = 4;
+
+  const int32_t(*cost_zero)[SIG_COEF_CONTEXTS] = txb_costs->base_cost_zero;
+  const uint16_t(*cost_low_tbl)[SIG_COEF_CONTEXTS][DQ_CTXS][2] =
+      txb_costs->base_cost_low_tbl;
+  const uint16_t(*cost_eob_tbl)[SIG_COEF_CONTEXTS_EOB][2] =
+      txb_costs->base_eob_cost_tbl;
+  const tran_low_t *absLevel = pq->absLevel;
+
+  // Calc zero coeff costs.
+  __m256i zero = _mm256_setzero_si256();
+  __m256i cost_zero_dq0 =
+      _mm256_lddqu_si256((__m256i *)&cost_zero[0][diag_ctx]);
+  __m256i cost_zero_dq1 =
+      _mm256_lddqu_si256((__m256i *)&cost_zero[1][diag_ctx]);
+
+  __m256i coef_ctx = _mm256_castsi128_si256(_mm_loadu_si64(&coeff_ctx->coef));
+  __m256i ctx16 = _mm256_unpacklo_epi8(coef_ctx, zero);
+  __m256i ctx = _mm256_shuffle_epi32(ctx16, 0xD8);
+  __m256i ctx_dq0 = _mm256_unpacklo_epi16(ctx, zero);
+  __m256i ctx_dq1 = _mm256_unpackhi_epi16(ctx, zero);
+  __m256i ratez_dq0 = _mm256_permutevar8x32_epi32(cost_zero_dq0, ctx_dq0);
+  __m256i ratez_dq1 = _mm256_permutevar8x32_epi32(cost_zero_dq1, ctx_dq1);
+  __m256i ratez_0123 = _mm256_unpacklo_epi64(ratez_dq0, ratez_dq1);
+  _mm_storeu_si128((__m128i *)&rd->rate_zero[0],
+                   _mm256_castsi256_si128(ratez_0123));
+
+  // Calc coeff_base rate.
+  int idx = AOMMIN(pq->qIdx - 1, 4);
+  __m128i c_zero = _mm_setzero_si128();
+  __m256i diag = _mm256_set1_epi16(diag_ctx);
+  __m256i base_ctx = _mm256_slli_epi16(ctx16, 12);
+  base_ctx = _mm256_srli_epi16(base_ctx, 12);
+  base_ctx = _mm256_add_epi16(base_ctx, diag);
+  for (int i = 0; i < (n_states >> 2); i++) {
+    int ctx0 = _mm256_extract_epi16(base_ctx, 0);
+    int ctx1 = _mm256_extract_epi16(base_ctx, 1);
+    int ctx2 = _mm256_extract_epi16(base_ctx, 2);
+    int ctx3 = _mm256_extract_epi16(base_ctx, 3);
+    base_ctx = _mm256_bsrli_epi128(base_ctx, 8);
+    __m128i rate_01 = _mm_loadu_si64(&cost_low_tbl[idx][ctx0][0]);
+    __m128i rate_23 = _mm_loadu_si64(&cost_low_tbl[idx][ctx1][0]);
+    __m128i rate_45 = _mm_loadu_si64(&cost_low_tbl[idx][ctx2][1]);
+    __m128i rate_67 = _mm_loadu_si64(&cost_low_tbl[idx][ctx3][1]);
+    __m128i rate_0123 = _mm_unpacklo_epi32(rate_01, rate_23);
+    __m128i rate_4567 = _mm_unpacklo_epi32(rate_45, rate_67);
+    rate_0123 = _mm_unpacklo_epi16(rate_0123, c_zero);
+    rate_4567 = _mm_unpacklo_epi16(rate_4567, c_zero);
+    _mm_storeu_si128((__m128i *)&rd->rate[8 * i], rate_0123);
+    _mm_storeu_si128((__m128i *)&rd->rate[8 * i + 4], rate_4567);
+  }
+
+  // Calc coeff/eob cost.
+  int eob_ctx = coeff_ctx->coef_eob;
+  __m128i rate_eob_coef = _mm_loadu_si64(&cost_eob_tbl[idx][eob_ctx][0]);
+  rate_eob_coef = _mm_unpacklo_epi16(rate_eob_coef, c_zero);
+  __m128i rate_eob_position = _mm_set1_epi32(eob_rate);
+  __m128i rate_eob = _mm_add_epi32(rate_eob_coef, rate_eob_position);
+  _mm_storeu_si64(&rd->rate_eob[0], rate_eob);
+
+  // Calc coeff mid and high range cost.
+  if (idx > 0) {
+    for (int i = 0; i < n_states; i++) {
+      int a0 = i & 2 ? 1 : 0;
+      int a1 = a0 + 2;
+      int mid_cost0 = get_mid_cost_def(absLevel[a0], coeff_ctx->coef[i],
+                                       txb_costs, 0, 0, 0);
+      int mid_cost1 = get_mid_cost_def(absLevel[a1], coeff_ctx->coef[i],
+                                       txb_costs, 0, 0, 0);
+      rd->rate[2 * i] += mid_cost0;
+      rd->rate[2 * i + 1] += mid_cost1;
+    }
+    int eob_mid_cost0 = get_mid_cost_eob(blk_pos, 0, 0, absLevel[0], 0, 0,
+                                         txb_costs, tx_class, 0, 0);
+    int eob_mid_cost1 = get_mid_cost_eob(blk_pos, 0, 0, absLevel[2], 0, 0,
+                                         txb_costs, tx_class, 0, 0);
+    rd->rate_eob[0] += eob_mid_cost0;
+    rd->rate_eob[1] += eob_mid_cost1;
+  }
+}
+
+void av1_calc_lf_ctx_st4_avx2(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
+                              struct tcq_coeff_ctx_t *coeff_ctx) {
+  int n_states = 4;
 
   int diag = kScanDiag[scan_pos];
   __m256i zero = _mm256_setzero_si256();
@@ -503,7 +712,7 @@
   __m256i base_mask = _mm256_permute2x128_si256(nbr_mask, nbr_mask, 0);
   __m256i mid_mask = _mm256_permute2x128_si256(nbr_mask, nbr_mask, 0x11);
 
-  for (int st = 0; st < TOTALSTATES; st += 4) {
+  for (int st = 0; st < n_states; st += 4) {
     // Load previously decoded LF context values.
     __m256i last01 = _mm256_lddqu_si256((__m256i *)&lf_ctx[st]);
     __m256i last23 = _mm256_lddqu_si256((__m256i *)&lf_ctx[st + 2]);
@@ -546,17 +755,86 @@
     ctx16 = _mm256_add_epi16(ctx16, mid_ctx_offset);
     __m128i ctx8 = _mm256_castsi256_si128(ctx16);
     ctx8 = _mm_packus_epi16(ctx8, ctx8);
-    _mm_storeu_si64(&coeff_ctx->coef[st], ctx8);
+#if 1
+    // Older compilers don't implement _mm_storeu_si32()
+    _mm_store_ss((float *)&coeff_ctx->coef[st], _mm_castsi128_ps(ctx8));
+#else
+    _mm_storeu_si32(&coeff_ctx->coef[st], ctx8);
+#endif
   }
 }
 
-void av1_update_lf_ctx_avx2(const struct tcq_node_t *decision,
+void av1_calc_lf_ctx_st8_avx2(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
+                              struct tcq_coeff_ctx_t *coeff_ctx) {
+  int n_states = 8;
+
+  int diag = kScanDiag[scan_pos];
+  __m256i zero = _mm256_setzero_si256();
+  __m256i nbr_mask = _mm256_lddqu_si256((__m256i *)kNbrMask[diag]);
+  __m256i base_mask = _mm256_permute2x128_si256(nbr_mask, nbr_mask, 0);
+  __m256i mid_mask = _mm256_permute2x128_si256(nbr_mask, nbr_mask, 0x11);
+
+  for (int st = 0; st < n_states; st += 4) {
+    // Load previously decoded LF context values.
+    __m256i last01 = _mm256_lddqu_si256((__m256i *)&lf_ctx[st]);
+    __m256i last23 = _mm256_lddqu_si256((__m256i *)&lf_ctx[st + 2]);
+
+    // Calc base ctx neighbor sum.
+    __m256i base01 = _mm256_min_epu8(last01, base_mask);
+    __m256i base23 = _mm256_min_epu8(last23, base_mask);
+    __m256i base01_sum = _mm256_sad_epu8(base01, zero);
+    __m256i base23_sum = _mm256_sad_epu8(base23, zero);
+    __m256i base_sum =
+        _mm256_hadd_epi32(base01_sum, base23_sum);  // B0 B0 B2 B2 B1 B1 B3 B3
+
+    // Calc mid ctx neighbor sum.
+    __m256i mid01 = _mm256_min_epu8(last01, mid_mask);
+    __m256i mid23 = _mm256_min_epu8(last23, mid_mask);
+    __m256i mid01_sum = _mm256_sad_epu8(mid01, zero);
+    __m256i mid23_sum = _mm256_sad_epu8(mid23, zero);
+    __m256i mid_sum =
+        _mm256_hadd_epi32(mid01_sum, mid23_sum);  // M0 M0 M2 M2 M1 M1 M3 M3
+
+    // Context calc; combine and reduce to 8 bits.
+    __m256i base_mid =
+        _mm256_hadd_epi32(base_sum, mid_sum);  // B0B2 M0M2 B1B3 M1M3
+    base_mid = _mm256_hadd_epi16(
+        base_mid, zero);  // reduce to 16 bits B0B2 M0M2 - - B1B3 M1M3 - -
+    base_mid = _mm256_avg_epu16(base_mid, zero);  // x = (x + 1) >> 1
+    base_mid = _mm256_shufflelo_epi16(
+        base_mid, 0xD8);  // shuffle B0M0 B2M2 - - B1M1 B3M3 - -
+    base_mid = _mm256_permute4x64_epi64(
+        base_mid, 0xD8);  // pack into lower half: B0M0 B2M2 B1M1 B3M3
+    base_mid = _mm256_shuffle_epi32(base_mid, 0xD8);  // B0M0 B1M1 B2M2 B3M3
+    __m256i six = _mm256_set1_epi16(6);
+    __m256i mid = _mm256_min_epi16(base_mid, six);
+    __m256i mid_sh4 = _mm256_slli_epi16(mid, 4);
+    __m256i base_max = _mm256_set1_epi16(kMaxCtx[scan_pos]);
+    __m256i base = _mm256_min_epi16(base_mid, base_max);
+    base_mid = _mm256_blend_epi16(base, mid_sh4, 0xAA);
+    __m256i ctx16 = _mm256_hadd_epi16(base_mid, base_mid);
+    __m256i mid_ctx_offset = _mm256_set1_epi16((scan_pos == 0) ? 0 : (7 << 4));
+    ctx16 = _mm256_add_epi16(ctx16, mid_ctx_offset);
+    __m128i ctx8 = _mm256_castsi256_si128(ctx16);
+    ctx8 = _mm_packus_epi16(ctx8, ctx8);
+#if 1
+    // Older compilers don't implement _mm_storeu_si32()
+    _mm_store_ss((float *)&coeff_ctx->coef[st], _mm_castsi128_ps(ctx8));
+#else
+    _mm_storeu_si32(&coeff_ctx->coef[st], ctx8);
+#endif
+  }
+}
+
+void av1_update_lf_ctx_avx2(const struct tcq_node_t *decision, int n_states,
                             struct tcq_lf_ctx_t *lf_ctx) {
-  __m256i upd_last_a;
-  __m256i upd_last_b;
-  __m256i upd_last_c;
-  __m256i upd_last_d;
-  for (int st = 0; st < TOTALSTATES; st += 2) {
+  __m256i c_zero = _mm256_setzero_si256();
+  __m256i upd_last_a = c_zero;
+  __m256i upd_last_b = c_zero;
+  __m256i upd_last_c = c_zero;
+  __m256i upd_last_d = c_zero;
+
+  for (int st = 0; st < n_states; st += 2) {
     int absLevel0 = decision[st].absLevel;
     int prevId0 = decision[st].prevId;
     int absLevel1 = decision[st + 1].absLevel;
@@ -580,26 +858,26 @@
     upd_last_b = upd_last_a;
     upd_last_a = upd01;
   }
-#if MORESTATES
-  _mm256_storeu_si256((__m256i *)lf_ctx[0].last, upd_last_d);
-  _mm256_storeu_si256((__m256i *)lf_ctx[2].last, upd_last_c);
-  _mm256_storeu_si256((__m256i *)lf_ctx[4].last, upd_last_b);
-  _mm256_storeu_si256((__m256i *)lf_ctx[6].last, upd_last_a);
-#else
-  (void)upd_last_d;
-  (void)upd_last_c;
-  _mm256_storeu_si256((__m256i *)lf_ctx[0].last, upd_last_b);
-  _mm256_storeu_si256((__m256i *)lf_ctx[2].last, upd_last_a);
-#endif
+  if (n_states == 4) {
+    (void)upd_last_d;
+    (void)upd_last_c;
+    _mm256_storeu_si256((__m256i *)lf_ctx[0].last, upd_last_b);
+    _mm256_storeu_si256((__m256i *)lf_ctx[2].last, upd_last_a);
+  } else {
+    _mm256_storeu_si256((__m256i *)lf_ctx[0].last, upd_last_d);
+    _mm256_storeu_si256((__m256i *)lf_ctx[2].last, upd_last_c);
+    _mm256_storeu_si256((__m256i *)lf_ctx[4].last, upd_last_b);
+    _mm256_storeu_si256((__m256i *)lf_ctx[6].last, upd_last_a);
+  }
 }
 
 void av1_get_rate_dist_lf_luma_avx2(const struct LV_MAP_COEFF_COST *txb_costs,
                                     const struct prequant_t *pq,
                                     const struct tcq_coeff_ctx_t *coeff_ctx,
                                     int blk_pos, int diag_ctx, int eob_rate,
-                                    int dc_sign_ctx, int32_t *tmp_sign, int bwl,
-                                    TX_CLASS tx_class, int coeff_sign,
-                                    struct tcq_rate_t *rd) {
+                                    int dc_sign_ctx, const int32_t *tmp_sign,
+                                    int bwl, TX_CLASS tx_class, int coeff_sign,
+                                    int n_states, struct tcq_rate_t *rd) {
 #define Z -1
   static const int8_t kShuf[2][32] = {
     { 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15,
@@ -636,18 +914,14 @@
   ratez = _mm256_permute4x64_epi64(ratez, 0x88);
   __m256i shuf1 = _mm256_lddqu_si256((__m256i *)kShuf[1]);
   ratez = _mm256_shuffle_epi8(ratez, shuf1);
-#if MORESTATES
   _mm256_storeu_si256((__m256i *)&rd->rate_zero[0], ratez);
-#else
-  _mm_storeu_si128((__m128i *)&rd->rate_zero[0], _mm256_castsi256_si128(ratez));
-#endif
 
   // Calc coeff_base rate.
   int idx = AOMMIN(pq->qIdx - 1, 8);
   __m128i c_zero = _mm_setzero_si128();
   __m256i diag = _mm256_set1_epi8(diag_ctx);
   base_ctx = _mm256_add_epi8(base_ctx, diag);
-  for (int i = 0; i < TOTALSTATES / 4; i++) {
+  for (int i = 0; i < (n_states >> 2); i++) {
     int ctx0 = _mm256_extract_epi8(base_ctx, 0);
     int ctx1 = _mm256_extract_epi8(base_ctx, 1);
     int ctx2 = _mm256_extract_epi8(base_ctx, 2);
@@ -680,7 +954,7 @@
   const bool dc_ver = (row == 0) && tx_class == TX_CLASS_VERT;
   const bool is_dc_coeff = dc_2dtx || dc_hor || dc_ver;
   if (is_dc_coeff) {
-    for (int i = 0; i < TOTALSTATES; i++) {
+    for (int i = 0; i < n_states; i++) {
       int a0 = i & 2 ? 1 : 0;
       int a1 = a0 + 2;
       int mid_cost0 = get_mid_cost_lf_dc(blk_pos, absLevel[a0], coeff_sign,
@@ -702,7 +976,134 @@
     rd->rate_eob[0] += eob_mid_cost0;
     rd->rate_eob[1] += eob_mid_cost1;
   } else if (idx > 4) {
-    for (int i = 0; i < TOTALSTATES; i++) {
+    for (int i = 0; i < n_states; i++) {
+      int a0 = i & 2 ? 1 : 0;
+      int a1 = a0 + 2;
+      int mid_cost0 =
+          get_mid_cost_lf(absLevel[a0], coeff_ctx->coef[i], txb_costs, plane);
+      int mid_cost1 =
+          get_mid_cost_lf(absLevel[a1], coeff_ctx->coef[i], txb_costs, plane);
+      rd->rate[2 * i] += mid_cost0;
+      rd->rate[2 * i + 1] += mid_cost1;
+    }
+    int t_sign = tmp_sign[blk_pos];
+    int eob_mid_cost0 =
+        get_mid_cost_eob(blk_pos, 1, 0, absLevel[0], coeff_sign, dc_sign_ctx,
+                         txb_costs, tx_class, t_sign, 0);
+    int eob_mid_cost1 =
+        get_mid_cost_eob(blk_pos, 1, 0, absLevel[2], coeff_sign, dc_sign_ctx,
+                         txb_costs, tx_class, t_sign, 0);
+    rd->rate_eob[0] += eob_mid_cost0;
+    rd->rate_eob[1] += eob_mid_cost1;
+  }
+}
+
+void av1_get_rate_dist_lf_luma_st4_avx2(
+    const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
+    const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx,
+    int eob_rate, int dc_sign_ctx, const int32_t *tmp_sign, int bwl,
+    TX_CLASS tx_class, int coeff_sign, int n_states, struct tcq_rate_t *rd) {
+  assert(n_states == 4);
+  n_states = 4;
+#define Z -1
+  static const int8_t kShuf[2][32] = {
+    { 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15,
+      0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 },
+    { 0, 8,  Z, Z, 1, 9,  Z, Z, 2, 10, Z, Z, 3, 11, Z, Z,
+      4, 12, Z, Z, 5, 13, Z, Z, 6, 14, Z, Z, 7, 15, Z, Z }
+  };
+  const uint16_t(*cost_zero)[LF_SIG_COEF_CONTEXTS] =
+      txb_costs->base_lf_cost_zero;
+  const uint16_t(*cost_low_tbl)[LF_SIG_COEF_CONTEXTS][DQ_CTXS][2] =
+      txb_costs->base_lf_cost_low_tbl;
+  const uint16_t(*cost_eob_tbl)[SIG_COEF_CONTEXTS_EOB][2] =
+      txb_costs->base_lf_eob_cost_tbl;
+  const tran_low_t *absLevel = pq->absLevel;
+  const int plane = 0;
+
+  // Calc zero coeff costs.
+  __m256i cost_zero_dq0 =
+      _mm256_lddqu_si256((__m256i *)&cost_zero[0][diag_ctx]);
+  __m256i cost_zero_dq1 =
+      _mm256_lddqu_si256((__m256i *)&cost_zero[1][diag_ctx]);
+  __m256i shuf = _mm256_lddqu_si256((__m256i *)kShuf[0]);
+  cost_zero_dq0 = _mm256_shuffle_epi8(cost_zero_dq0, shuf);
+  cost_zero_dq1 = _mm256_shuffle_epi8(cost_zero_dq1, shuf);
+  __m256i cost_dq0 = _mm256_permute4x64_epi64(cost_zero_dq0, 0xD8);
+  __m256i cost_dq1 = _mm256_permute4x64_epi64(cost_zero_dq1, 0xD8);
+  __m256i ctx = _mm256_castsi128_si256(_mm_loadu_si64(&coeff_ctx->coef));
+  __m256i fifteen = _mm256_set1_epi8(15);
+  __m256i base_ctx = _mm256_and_si256(ctx, fifteen);
+  __m256i base_ctx1 = _mm256_permute4x64_epi64(base_ctx, 0);
+  __m256i ratez_dq0 = _mm256_shuffle_epi8(cost_dq0, base_ctx1);
+  __m256i ratez_dq1 = _mm256_shuffle_epi8(cost_dq1, base_ctx1);
+  __m256i ratez = _mm256_blend_epi16(ratez_dq0, ratez_dq1, 0xAA);
+  ratez = _mm256_permute4x64_epi64(ratez, 0x88);
+  __m256i shuf1 = _mm256_lddqu_si256((__m256i *)kShuf[1]);
+  ratez = _mm256_shuffle_epi8(ratez, shuf1);
+  _mm256_storeu_si256((__m256i *)&rd->rate_zero[0], ratez);
+
+  // Calc coeff_base rate.
+  int idx = AOMMIN(pq->qIdx - 1, 8);
+  __m128i c_zero = _mm_setzero_si128();
+  __m256i diag = _mm256_set1_epi8(diag_ctx);
+  base_ctx = _mm256_add_epi8(base_ctx, diag);
+  for (int i = 0; i < (n_states >> 2); i++) {
+    int ctx0 = _mm256_extract_epi8(base_ctx, 0);
+    int ctx1 = _mm256_extract_epi8(base_ctx, 1);
+    int ctx2 = _mm256_extract_epi8(base_ctx, 2);
+    int ctx3 = _mm256_extract_epi8(base_ctx, 3);
+    base_ctx = _mm256_bsrli_epi128(base_ctx, 4);
+    __m128i rate_01 = _mm_loadu_si64(&cost_low_tbl[idx][ctx0][0]);
+    __m128i rate_23 = _mm_loadu_si64(&cost_low_tbl[idx][ctx1][0]);
+    __m128i rate_45 = _mm_loadu_si64(&cost_low_tbl[idx][ctx2][1]);
+    __m128i rate_67 = _mm_loadu_si64(&cost_low_tbl[idx][ctx3][1]);
+    __m128i rate_0123 = _mm_unpacklo_epi32(rate_01, rate_23);
+    __m128i rate_4567 = _mm_unpacklo_epi32(rate_45, rate_67);
+    rate_0123 = _mm_unpacklo_epi16(rate_0123, c_zero);
+    rate_4567 = _mm_unpacklo_epi16(rate_4567, c_zero);
+    _mm_storeu_si128((__m128i *)&rd->rate[8 * i], rate_0123);
+    _mm_storeu_si128((__m128i *)&rd->rate[8 * i + 4], rate_4567);
+  }
+
+  // Calc coeff/eob cost.
+  int eob_ctx = coeff_ctx->coef_eob;
+  __m128i rate_eob_coef = _mm_loadu_si64(&cost_eob_tbl[idx][eob_ctx][0]);
+  rate_eob_coef = _mm_unpacklo_epi16(rate_eob_coef, c_zero);
+  __m128i rate_eob_position = _mm_set1_epi32(eob_rate);
+  __m128i rate_eob = _mm_add_epi32(rate_eob_coef, rate_eob_position);
+  _mm_storeu_si64(&rd->rate_eob[0], rate_eob);
+
+  const int row = blk_pos >> bwl;
+  const int col = blk_pos - (row << bwl);
+  const bool dc_2dtx = (blk_pos == 0);
+  const bool dc_hor = (col == 0) && tx_class == TX_CLASS_HORIZ;
+  const bool dc_ver = (row == 0) && tx_class == TX_CLASS_VERT;
+  const bool is_dc_coeff = dc_2dtx || dc_hor || dc_ver;
+  if (is_dc_coeff) {
+    for (int i = 0; i < n_states; i++) {
+      int a0 = i & 2 ? 1 : 0;
+      int a1 = a0 + 2;
+      int mid_cost0 = get_mid_cost_lf_dc(blk_pos, absLevel[a0], coeff_sign,
+                                         coeff_ctx->coef[i], dc_sign_ctx,
+                                         txb_costs, tmp_sign, plane);
+      int mid_cost1 = get_mid_cost_lf_dc(blk_pos, absLevel[a1], coeff_sign,
+                                         coeff_ctx->coef[i], dc_sign_ctx,
+                                         txb_costs, tmp_sign, plane);
+      rd->rate[2 * i] += mid_cost0;
+      rd->rate[2 * i + 1] += mid_cost1;
+    }
+    int t_sign = tmp_sign[blk_pos];
+    int eob_mid_cost0 =
+        get_mid_cost_eob(blk_pos, 1, 1, absLevel[0], coeff_sign, dc_sign_ctx,
+                         txb_costs, tx_class, t_sign, 0);
+    int eob_mid_cost1 =
+        get_mid_cost_eob(blk_pos, 1, 1, absLevel[2], coeff_sign, dc_sign_ctx,
+                         txb_costs, tx_class, t_sign, 0);
+    rd->rate_eob[0] += eob_mid_cost0;
+    rd->rate_eob[1] += eob_mid_cost1;
+  } else if (idx > 4) {
+    for (int i = 0; i < n_states; i++) {
       int a0 = i & 2 ? 1 : 0;
       int a1 = a0 + 2;
       int mid_cost0 =
@@ -728,9 +1129,10 @@
                                       const struct prequant_t *pq,
                                       const struct tcq_coeff_ctx_t *coeff_ctx,
                                       int blk_pos, int diag_ctx, int eob_rate,
-                                      int dc_sign_ctx, int32_t *tmp_sign,
+                                      int dc_sign_ctx, const int32_t *tmp_sign,
                                       int bwl, TX_CLASS tx_class, int plane,
-                                      int coeff_sign, struct tcq_rate_t *rd) {
+                                      int coeff_sign, int n_states,
+                                      struct tcq_rate_t *rd) {
   (void)bwl;
 #define Z -1
   static const int8_t kShuf[2][32] = {
@@ -768,18 +1170,14 @@
   ratez = _mm256_permute4x64_epi64(ratez, 0x88);
   __m256i shuf1 = _mm256_lddqu_si256((__m256i *)kShuf[1]);
   ratez = _mm256_shuffle_epi8(ratez, shuf1);
-#if MORESTATES
   _mm256_storeu_si256((__m256i *)&rd->rate_zero[0], ratez);
-#else
-  _mm_storeu_si128((__m128i *)&rd->rate_zero[0], _mm256_castsi256_si128(ratez));
-#endif
 
   // Calc coeff_base rate.
   int idx = AOMMIN(pq->qIdx - 1, 8);
   __m128i c_zero = _mm_setzero_si128();
   __m256i diag = _mm256_set1_epi8(diag_ctx);
   base_ctx = _mm256_add_epi8(base_ctx, diag);
-  for (int i = 0; i < TOTALSTATES / 4; i++) {
+  for (int i = 0; i < (n_states >> 2); i++) {
     int ctx0 = _mm256_extract_epi8(base_ctx, 0);
     int ctx1 = _mm256_extract_epi8(base_ctx, 1);
     int ctx2 = _mm256_extract_epi8(base_ctx, 2);
@@ -817,7 +1215,7 @@
   const bool is_dc_coeff = dc_2dtx || dc_hor || dc_ver;
 #endif
   if (is_dc_coeff) {
-    for (int i = 0; i < TOTALSTATES; i++) {
+    for (int i = 0; i < n_states; i++) {
       int a0 = i & 2 ? 1 : 0;
       int a1 = a0 + 2;
       int mid_cost0 = get_mid_cost_lf_dc(blk_pos, absLevel[a0], coeff_sign,
@@ -839,7 +1237,7 @@
     rd->rate_eob[0] += eob_mid_cost0;
     rd->rate_eob[1] += eob_mid_cost1;
   } else if (idx > 4) {
-    for (int i = 0; i < TOTALSTATES; i++) {
+    for (int i = 0; i < n_states; i++) {
       int a0 = i & 2 ? 1 : 0;
       int a1 = a0 + 2;
       int mid_cost0 =
@@ -865,7 +1263,7 @@
     const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
     const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl,
     TX_CLASS tx_class, int diag_ctx, int eob_rate, int plane, int t_sign,
-    int sign, struct tcq_rate_t *rd) {
+    int sign, int n_states, struct tcq_rate_t *rd) {
   (void)bwl;
   const int32_t(*cost_zero)[SIG_COEF_CONTEXTS] = txb_costs->base_cost_uv_zero;
   const uint16_t(*cost_low_tbl)[SIG_COEF_CONTEXTS][DQ_CTXS][2] =
@@ -890,11 +1288,9 @@
   __m256i ratez_0123 = _mm256_unpacklo_epi64(ratez_dq0, ratez_dq1);
   _mm_storeu_si128((__m128i *)&rd->rate_zero[0],
                    _mm256_castsi256_si128(ratez_0123));
-#if MORESTATES
   __m256i ratez_4567 = _mm256_unpackhi_epi64(ratez_dq0, ratez_dq1);
   _mm_storeu_si128((__m128i *)&rd->rate_zero[4],
                    _mm256_castsi256_si128(ratez_4567));
-#endif
 
   // Calc coeff_base rate.
   int idx = AOMMIN(pq->qIdx - 1, 4);
@@ -903,7 +1299,7 @@
   __m256i base_ctx = _mm256_slli_epi16(ctx16, 12);
   base_ctx = _mm256_srli_epi16(base_ctx, 12);
   base_ctx = _mm256_add_epi16(base_ctx, diag);
-  for (int i = 0; i < TOTALSTATES / 4; i++) {
+  for (int i = 0; i < (n_states >> 2); i++) {
     int ctx0 = _mm256_extract_epi16(base_ctx, 0);
     int ctx1 = _mm256_extract_epi16(base_ctx, 1);
     int ctx2 = _mm256_extract_epi16(base_ctx, 2);
@@ -931,7 +1327,7 @@
 
   // Calc coeff mid and high range cost.
   if (idx > 0 || plane) {
-    for (int i = 0; i < TOTALSTATES; i++) {
+    for (int i = 0; i < n_states; i++) {
       int a0 = i & 2 ? 1 : 0;
       int a1 = a0 + 2;
       int mid_cost0 = get_mid_cost_def(absLevel[a0], coeff_ctx->coef[i],
@@ -1085,17 +1481,18 @@
   return dqv;
 }
 
-int av1_find_best_path_avx2(const struct tcq_node_t *trellis,
+int av1_find_best_path_avx2(const struct tcq_node_t *trellis, int n_states_log2,
                             const int16_t *scan, const int32_t *dequant,
                             const qm_val_t *iqmatrix, const tran_low_t *tcoeff,
                             int first_scan_pos, int log_scale,
                             tran_low_t *qcoeff, tran_low_t *dqcoeff,
                             int *min_rate, int64_t *min_cost) {
   // Select best trellis state.
+  int n_states = 1 << n_states_log2;
   int64_t min_path_cost = INT64_MAX;
   int trel_min_rate = INT32_MAX;
   int prev_id = -2;
-  for (int state = 0; state < TOTALSTATES; state++) {
+  for (int state = 0; state < n_states; state++) {
     const tcq_node_t *decision = &trellis[state];
     if (decision->rdCost < min_path_cost) {
       prev_id = state;
@@ -1114,7 +1511,7 @@
     int shift = QUANT_TABLE_BITS + log_scale;
     for (; prev_id >= 0; scan_pos++) {
       const int32_t *decision =
-          (int32_t *)&trellis[scan_pos * TOTALSTATES + prev_id];
+          (int32_t *)&trellis[(scan_pos << n_states_log2) + prev_id];
       __m128i info = _mm_loadu_si64(&decision[3]);
       int blk_pos = scan[scan_pos];
       __m128i sign = _mm_loadu_si64(&tcoeff[blk_pos]);
@@ -1149,7 +1546,8 @@
     }
   } else {
     for (; prev_id >= 0; scan_pos++) {
-      const tcq_node_t *decision = &trellis[scan_pos * TOTALSTATES + prev_id];
+      const tcq_node_t *decision =
+          &trellis[(scan_pos << n_states_log2) + prev_id];
       prev_id = decision->prevId;
       int abs_level = decision->absLevel;
       int blk_pos = scan[scan_pos];