diff --git a/av1/encoder/av1_quantize.h b/av1/encoder/av1_quantize.h
index d4fa0b2..4b306d9 100644
--- a/av1/encoder/av1_quantize.h
+++ b/av1/encoder/av1_quantize.h
@@ -31,6 +31,8 @@
   const qm_val_t *qmatrix;
   const qm_val_t *iqmatrix;
   int use_quant_b_adapt;
+  int use_optimize_b;
+  int xform_quant_idx;
 } QUANT_PARAM;
 
 typedef void (*AV1_QUANT_FACADE)(const tran_low_t *coeff_ptr, intptr_t n_coeffs,
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 143df94..9e87527 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -126,8 +126,7 @@
 #endif
 
 void av1_xform_quant(MACROBLOCK *x, int plane, int block, int blk_row,
-                     int blk_col, BLOCK_SIZE plane_bsize,
-                     AV1_XFORM_QUANT xform_quant_idx, TxfmParam *txfm_param,
+                     int blk_col, BLOCK_SIZE plane_bsize, TxfmParam *txfm_param,
                      QUANT_PARAM *qparam) {
   MACROBLOCKD *const xd = &x->e_mbd;
   const struct macroblock_plane *const p = &x->plane[plane];
@@ -146,30 +145,27 @@
 
   av1_fwd_txfm(src_diff, coeff, diff_stride, txfm_param);
 
-  if (xform_quant_idx != AV1_XFORM_QUANT_SKIP_QUANT) {
+  if (qparam->xform_quant_idx != AV1_XFORM_QUANT_SKIP_QUANT) {
     const int n_coeffs = av1_get_max_eob(txfm_param->tx_size);
     if (LIKELY(!x->skip_block)) {
 #if CONFIG_AV1_HIGHBITDEPTH
-      quant_func_list[xform_quant_idx][txfm_param->is_hbd](
+      quant_func_list[qparam->xform_quant_idx][txfm_param->is_hbd](
           coeff, n_coeffs, p, qcoeff, dqcoeff, eob, scan_order, qparam);
 #else
-      quant_func_list[xform_quant_idx](coeff, n_coeffs, p, qcoeff, dqcoeff, eob,
-                                       scan_order, qparam);
+      quant_func_list[qparam->xform_quant_idx](
+          coeff, n_coeffs, p, qcoeff, dqcoeff, eob, scan_order, qparam);
 #endif
     } else {
       av1_quantize_skip(n_coeffs, qcoeff, dqcoeff, eob);
     }
   }
-  // NOTE: optimize_b_following is true means av1_optimze_b will be called
-  // When the condition of doing optimize_b is changed,
-  // this flag need update simultaneously
-  const int optimize_b_following =
-      (xform_quant_idx != AV1_XFORM_QUANT_FP) || (txfm_param->lossless);
-  if (optimize_b_following) {
+  // When use_optimize_b is set, av1_optimize_b will be called afterwards,
+  // so the entropy ctx cannot be updated now (performed in optimize_b)
+  if (qparam->use_optimize_b) {
+    p->txb_entropy_ctx[block] = 0;
+  } else {
     p->txb_entropy_ctx[block] =
         (uint8_t)av1_get_txb_entropy_context(qcoeff, scan_order, *eob);
-  } else {
-    p->txb_entropy_ctx[block] = 0;
   }
   return;
 }
@@ -188,12 +184,20 @@
   txfm_param->bd = xd->bd;
   txfm_param->is_hbd = is_cur_buf_hbd(xd);
 }
-void av1_setup_quant(const AV1_COMMON *cm, TX_SIZE tx_size,
-                     QUANT_PARAM *qparam) {
+void av1_setup_quant(const AV1_COMMON *cm, TX_SIZE tx_size, int use_optimize_b,
+                     int xform_quant_idx, QUANT_PARAM *qparam) {
   qparam->log_scale = av1_get_tx_scale(tx_size);
   qparam->tx_size = tx_size;
 
   qparam->use_quant_b_adapt = cm->use_quant_b_adapt;
+
+  // TODO(bohanli): optimize_b and the quantization idx are related,
+  // but the relationship is buried and complicated across encoding stages.
+  // There should be a unified function to derive quant_idx, rather than
+  // having callers determine and pass in the quant_idx.
+  qparam->use_optimize_b = use_optimize_b;
+  qparam->xform_quant_idx = xform_quant_idx;
+
   qparam->qmatrix = NULL;
   qparam->iqmatrix = NULL;
 }
@@ -247,25 +251,25 @@
                               cm->reduced_tx_set_used);
     TxfmParam txfm_param;
     QUANT_PARAM quant_param;
+    int use_trellis = (args->enable_optimize_b != NO_TRELLIS_OPT);
+    int quant_idx;
+    if (use_trellis && args->enable_optimize_b != FINAL_PASS_TRELLIS_OPT) {
+      quant_idx = AV1_XFORM_QUANT_FP;
+    } else {
+      quant_idx =
+          USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP;
+    }
     av1_setup_xform(cm, x, tx_size, tx_type, &txfm_param);
-    av1_setup_quant(cm, tx_size, &quant_param);
+    av1_setup_quant(cm, tx_size, use_trellis, quant_idx, &quant_param);
     av1_setup_qmatrix(cm, x, plane, tx_size, tx_type, &quant_param);
-    if (args->enable_optimize_b != NO_TRELLIS_OPT) {
-      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
-                      USE_B_QUANT_NO_TRELLIS && (args->enable_optimize_b ==
-                                                 FINAL_PASS_TRELLIS_OPT)
-                          ? AV1_XFORM_QUANT_B
-                          : AV1_XFORM_QUANT_FP,
-                      &txfm_param, &quant_param);
+    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+                    &quant_param);
+
+    if (quant_param.use_optimize_b) {
       TXB_CTX txb_ctx;
       get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
       av1_optimize_b(args->cpi, x, plane, block, tx_size, tx_type, &txb_ctx,
                      args->cpi->sf.trellis_eob_fast, &dummy_rate_cost);
-    } else {
-      av1_xform_quant(
-          x, plane, block, blk_row, blk_col, plane_bsize,
-          USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP,
-          &txfm_param, &quant_param);
     }
   } else {
     p->eobs[block] = 0;
@@ -456,11 +460,11 @@
   QUANT_PARAM quant_param;
 
   av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
-  av1_setup_quant(cm, tx_size, &quant_param);
+  av1_setup_quant(cm, tx_size, 0, AV1_XFORM_QUANT_B, &quant_param);
   av1_setup_qmatrix(cm, x, plane, tx_size, DCT_DCT, &quant_param);
 
-  av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
-                  AV1_XFORM_QUANT_B, &txfm_param, &quant_param);
+  av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+                  &quant_param);
 
   if (p->eobs[block] > 0) {
     txfm_param.eob = p->eobs[block];
@@ -608,25 +612,26 @@
                               cm->reduced_tx_set_used);
     TxfmParam txfm_param;
     QUANT_PARAM quant_param;
+    int use_trellis = args->enable_optimize_b != NO_TRELLIS_OPT;
+    int quant_idx;
+    if (use_trellis && args->enable_optimize_b != FINAL_PASS_TRELLIS_OPT)
+      quant_idx = AV1_XFORM_QUANT_FP;
+    else
+      quant_idx =
+          USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP;
+
     av1_setup_xform(cm, x, tx_size, tx_type, &txfm_param);
-    av1_setup_quant(cm, tx_size, &quant_param);
+    av1_setup_quant(cm, tx_size, use_trellis, quant_idx, &quant_param);
     av1_setup_qmatrix(cm, x, plane, tx_size, tx_type, &quant_param);
-    if (args->enable_optimize_b != NO_TRELLIS_OPT) {
-      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
-                      USE_B_QUANT_NO_TRELLIS && (args->enable_optimize_b ==
-                                                 FINAL_PASS_TRELLIS_OPT)
-                          ? AV1_XFORM_QUANT_B
-                          : AV1_XFORM_QUANT_FP,
-                      &txfm_param, &quant_param);
+
+    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+                    &quant_param);
+
+    if (quant_param.use_optimize_b) {
       TXB_CTX txb_ctx;
       get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
       av1_optimize_b(args->cpi, x, plane, block, tx_size, tx_type, &txb_ctx,
                      args->cpi->sf.trellis_eob_fast, &dummy_rate_cost);
-    } else {
-      av1_xform_quant(
-          x, plane, block, blk_row, blk_col, plane_bsize,
-          USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP,
-          &txfm_param, &quant_param);
     }
   }
 
diff --git a/av1/encoder/encodemb.h b/av1/encoder/encodemb.h
index 0a38859..dc9b190 100644
--- a/av1/encoder/encodemb.h
+++ b/av1/encoder/encodemb.h
@@ -61,14 +61,13 @@
 
 void av1_setup_xform(const AV1_COMMON *cm, MACROBLOCK *x, TX_SIZE tx_size,
                      TX_TYPE tx_type, TxfmParam *txfm_param);
-void av1_setup_quant(const AV1_COMMON *cm, TX_SIZE tx_size,
-                     QUANT_PARAM *qparam);
+void av1_setup_quant(const AV1_COMMON *cm, TX_SIZE tx_size, int use_optimize_b,
+                     int xform_quant_idx, QUANT_PARAM *qparam);
 void av1_setup_qmatrix(const AV1_COMMON *cm, MACROBLOCK *x, int plane,
                        TX_SIZE tx_size, TX_TYPE tx_type, QUANT_PARAM *qparam);
 
 void av1_xform_quant(MACROBLOCK *x, int plane, int block, int blk_row,
-                     int blk_col, BLOCK_SIZE plane_bsize,
-                     AV1_XFORM_QUANT xform_quant_idx, TxfmParam *txfm_param,
+                     int blk_col, BLOCK_SIZE plane_bsize, TxfmParam *txfm_param,
                      QUANT_PARAM *qparam);
 
 int av1_optimize_b(const struct AV1_COMP *cpi, MACROBLOCK *mb, int plane,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index c9b7eed..7b8f9e0 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3314,6 +3314,7 @@
         best_eob = intra_txb_rd_info->eob;
         best_tx_type = intra_txb_rd_info->tx_type;
         perform_block_coeff_opt = intra_txb_rd_info->perform_block_coeff_opt;
+        skip_trellis |= !perform_block_coeff_opt;
         update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
         goto RECON_INTRA;
       }
@@ -3458,13 +3459,18 @@
   // larger residuals, R-D optimization may not be effective.
   // TODO(any): Experiment with variance and mean based thresholds
   perform_block_coeff_opt = (block_mse_q8 <= x->coeff_opt_dist_threshold);
+  skip_trellis |= !perform_block_coeff_opt;
 
   assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
 
   TxfmParam txfm_param;
   QUANT_PARAM quant_param;
   av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
-  av1_setup_quant(cm, tx_size, &quant_param);
+  av1_setup_quant(cm, tx_size, !skip_trellis,
+                  skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
+                                                         : AV1_XFORM_QUANT_FP)
+                               : AV1_XFORM_QUANT_FP,
+                  &quant_param);
   int use_qm = !(xd->lossless[mbmi->segment_id] || cm->using_qmatrix == 0);
 
   for (int idx = 0; idx < TX_TYPES; ++idx) {
@@ -3477,17 +3483,11 @@
     if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
     RD_STATS this_rd_stats;
     av1_invalid_rd_stats(&this_rd_stats);
-    if (skip_trellis || (!perform_block_coeff_opt)) {
-      av1_xform_quant(
-          x, plane, block, blk_row, blk_col, plane_bsize,
-          USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP,
-          &txfm_param, &quant_param);
-      rate_cost =
-          av1_cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
-                          use_fast_coef_costing, cm->reduced_tx_set_used);
-    } else {
-      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
-                      AV1_XFORM_QUANT_FP, &txfm_param, &quant_param);
+
+    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+                    &quant_param);
+
+    if (quant_param.use_optimize_b) {
       if (cpi->sf.optimize_b_precheck && best_rd < INT64_MAX &&
           eobs_ptr[block] >= 4) {
         // Calculate distortion quickly in transform domain.
@@ -3501,7 +3501,12 @@
       }
       av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
                      cpi->sf.trellis_eob_fast, &rate_cost);
+    } else {
+      rate_cost =
+          av1_cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
+                          use_fast_coef_costing, cm->reduced_tx_set_used);
     }
+
     // If rd cost based on coeff rate is more than best_rd, skip the calculation
     // of distortion
     int64_t tmp_rd = RDCOST(x->rdmult, rate_cost, 0);
@@ -3660,18 +3665,17 @@
       TxfmParam txfm_param_intra;
       QUANT_PARAM quant_param_intra;
       av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
-      av1_setup_quant(cm, tx_size, &quant_param_intra);
+      av1_setup_quant(cm, tx_size, !skip_trellis,
+                      skip_trellis
+                          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
+                                                    : AV1_XFORM_QUANT_FP)
+                          : AV1_XFORM_QUANT_FP,
+                      &quant_param_intra);
       av1_setup_qmatrix(cm, x, plane, tx_size, best_tx_type,
                         &quant_param_intra);
-      if (skip_trellis || (!perform_block_coeff_opt)) {
-        av1_xform_quant(
-            x, plane, block, blk_row, blk_col, plane_bsize,
-            USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP,
-            &txfm_param_intra, &quant_param_intra);
-      } else {
-        av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
-                        AV1_XFORM_QUANT_FP, &txfm_param_intra,
-                        &quant_param_intra);
+      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
+                      &txfm_param_intra, &quant_param_intra);
+      if (quant_param_intra.use_optimize_b) {
         av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
                        cpi->sf.trellis_eob_fast, &rate_cost);
       }
