Refactor av1_optimize_txb(_new)

Tested this with city_cif.y4m with limit=2 bitrate=1000 with speed 1
Encoder was sped up by 5%

Coding performance changes by 0.026% on lowres limit=30

STATS_CHANGED

Change-Id: I96347e94f6b6ff74072c8198d023f1947db91b8a
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index ee41156..5705b9c 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -583,6 +583,21 @@
   return get_nz_map_ctx_from_stats(stats, coeff_idx, bwl, tx_size, tx_class);
 }
 
+static INLINE int get_lower_levels_ctx_general(int is_last, int scan_idx,
+                                               int bwl, int height,
+                                               const uint8_t *levels,
+                                               int coeff_idx, TX_SIZE tx_size,
+                                               TX_TYPE tx_type) {
+  if (is_last) {
+    if (scan_idx == 0) return 0;
+    if (scan_idx <= (height << bwl) >> 3) return 1;
+    if (scan_idx <= (height << bwl) >> 2) return 2;
+    return 3;
+  } else {
+    return get_lower_levels_ctx(levels, coeff_idx, bwl, tx_size, tx_type);
+  }
+}
+
 static INLINE void set_dc_sign(int *cul_level, tran_low_t v) {
   if (v < 0)
     *cul_level |= 1 << COEFF_CONTEXT_BITS;
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index e0e54be..522e5d0 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -112,8 +112,9 @@
     return eob;
   }
 
-  return av1_optimize_txb(cpi, mb, plane, blk_row, blk_col, block, tx_size,
-                          &txb_ctx, fast_mode, rate_cost);
+  (void)fast_mode;
+  return av1_optimize_txb_new(cpi, mb, plane, blk_row, blk_col, block, tx_size,
+                              &txb_ctx, rate_cost);
 }
 
 typedef enum QUANT_FUNC {
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 82432b0..74b2fc4 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -240,10 +240,71 @@
   return eob_cost;
 }
 
+static INLINE int get_sign_bit_cost(tran_low_t qc, int coeff_idx,
+                                    const int (*dc_sign_cost)[2],
+                                    int dc_sign_ctx) {
+  const int sign = (qc < 0) ? 1 : 0;
+  // sign bit cost
+  if (coeff_idx == 0) {
+    return dc_sign_cost[dc_sign_ctx][sign];
+  } else {
+    return av1_cost_literal(1);
+  }
+}
+
+static INLINE int get_br_cost(tran_low_t abs_qc, int ctx,
+                              const int *coeff_lps) {
+  const tran_low_t min_level = 1 + NUM_BASE_LEVELS;
+  const tran_low_t max_level = 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE;
+  (void)ctx;
+  if (abs_qc >= min_level) {
+    if (abs_qc >= max_level)
+      return coeff_lps[COEFF_BASE_RANGE];  // COEFF_BASE_RANGE * cost0;
+    else
+      return coeff_lps[(abs_qc - min_level)];  //  * cost0 + cost1;
+  } else {
+    return 0;
+  }
+}
+
+static INLINE int get_golomb_cost(int abs_qc) {
+  if (abs_qc >= 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
+    const int r = abs_qc - COEFF_BASE_RANGE - NUM_BASE_LEVELS;
+    const int length = get_msb(r) + 1;
+    return av1_cost_literal(2 * length - 1);
+  }
+  return 0;
+}
+
 static int get_coeff_cost(const tran_low_t qc, const int scan_idx,
                           const int is_eob, const TxbInfo *const txb_info,
                           const LV_MAP_COEFF_COST *const txb_costs,
-                          const int coeff_ctx);
+                          const int coeff_ctx) {
+  const TXB_CTX *txb_ctx = txb_info->txb_ctx;
+  const int is_nz = (qc != 0);
+  const tran_low_t abs_qc = abs(qc);
+  int cost = 0;
+  const int16_t *const scan = txb_info->scan_order->scan;
+  const int pos = scan[scan_idx];
+
+  if (is_eob) {
+    cost += txb_costs->base_eob_cost[coeff_ctx][AOMMIN(abs_qc, 3) - 1];
+  } else {
+    cost += txb_costs->base_cost[coeff_ctx][AOMMIN(abs_qc, 3)];
+  }
+  if (is_nz) {
+    cost += get_sign_bit_cost(qc, scan_idx, txb_costs->dc_sign_cost,
+                              txb_ctx->dc_sign_ctx);
+
+    if (abs_qc > NUM_BASE_LEVELS) {
+      const int ctx =
+          get_br_ctx(txb_info->levels, pos, txb_info->bwl, txb_info->tx_type);
+      cost += get_br_cost(abs_qc, ctx, txb_costs->lps_cost[ctx]);
+      cost += get_golomb_cost(abs_qc);
+    }
+  }
+  return cost;
+}
 
 static INLINE int get_nz_map_ctx(const uint8_t *const levels,
                                  const int coeff_idx, const int bwl,
@@ -593,30 +654,6 @@
   }
 }
 
-static INLINE int get_br_cost(tran_low_t abs_qc, int ctx,
-                              const int *coeff_lps) {
-  const tran_low_t min_level = 1 + NUM_BASE_LEVELS;
-  const tran_low_t max_level = 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE;
-  (void)ctx;
-  if (abs_qc >= min_level) {
-    if (abs_qc >= max_level)
-      return coeff_lps[COEFF_BASE_RANGE];  // COEFF_BASE_RANGE * cost0;
-    else
-      return coeff_lps[(abs_qc - min_level)];  //  * cost0 + cost1;
-  } else {
-    return 0;
-  }
-}
-
-static INLINE int get_golomb_cost(int abs_qc) {
-  if (abs_qc >= 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
-    const int r = abs_qc - COEFF_BASE_RANGE - NUM_BASE_LEVELS;
-    const int length = get_msb(r) + 1;
-    return av1_cost_literal(2 * length - 1);
-  }
-  return 0;
-}
-
 // Note: don't call this function when eob is 0.
 int av1_cost_coeffs_txb(const AV1_COMMON *const cm, const MACROBLOCK *x,
                         const int plane, const int blk_row, const int blk_col,
@@ -646,8 +683,9 @@
   const int eob_multi_size = txsize_log2_minus4[tx_size];
   const LV_MAP_EOB_COST *const eob_costs =
       &x->eob_costs[eob_multi_size][plane_type];
-  // eob must be greater than 0 here.
-  assert(eob > 0);
+  if (eob == 0) {
+    return coeff_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][1];
+  }
   cost = coeff_costs->txb_skip_cost[txb_skip_ctx][0];
 
   av1_txb_init_levels(qcoeff, width, height, levels);
@@ -709,53 +747,11 @@
   return abs(qc) >= 1 + NUM_BASE_LEVELS;
 }
 
-static INLINE int get_sign_bit_cost(tran_low_t qc, int coeff_idx,
-                                    const int (*dc_sign_cost)[2],
-                                    int dc_sign_ctx) {
-  const int sign = (qc < 0) ? 1 : 0;
-  // sign bit cost
-  if (coeff_idx == 0) {
-    return dc_sign_cost[dc_sign_ctx][sign];
-  } else {
-    return av1_cost_literal(1);
-  }
-}
-
 static INLINE void set_eob(TxbInfo *txb_info, int eob) {
   txb_info->eob = eob;
   txb_info->seg_eob = av1_get_max_eob(txb_info->tx_size);
 }
 
-static int get_coeff_cost(const tran_low_t qc, const int scan_idx,
-                          const int is_eob, const TxbInfo *const txb_info,
-                          const LV_MAP_COEFF_COST *const txb_costs,
-                          const int coeff_ctx) {
-  const TXB_CTX *txb_ctx = txb_info->txb_ctx;
-  const int is_nz = (qc != 0);
-  const tran_low_t abs_qc = abs(qc);
-  int cost = 0;
-  const int16_t *const scan = txb_info->scan_order->scan;
-  const int pos = scan[scan_idx];
-
-  if (is_eob) {
-    cost += txb_costs->base_eob_cost[coeff_ctx][AOMMIN(abs_qc, 3) - 1];
-  } else {
-    cost += txb_costs->base_cost[coeff_ctx][AOMMIN(abs_qc, 3)];
-  }
-  if (is_nz) {
-    cost += get_sign_bit_cost(qc, scan_idx, txb_costs->dc_sign_cost,
-                              txb_ctx->dc_sign_ctx);
-
-    if (abs_qc > NUM_BASE_LEVELS) {
-      const int ctx =
-          get_br_ctx(txb_info->levels, pos, txb_info->bwl, txb_info->tx_type);
-      cost += get_br_cost(abs_qc, ctx, txb_costs->lps_cost[ctx]);
-      cost += get_golomb_cost(abs_qc);
-    }
-  }
-  return cost;
-}
-
 static int optimize_txb(TxbInfo *txb_info, const LV_MAP_COEFF_COST *txb_costs,
                         const LV_MAP_EOB_COST *txb_eob_costs, int *rate_cost) {
   int update = 0;
@@ -1193,6 +1189,403 @@
                           txb_eob_costs, p, block, fast_mode, rate_cost);
 }
 
+static INLINE int get_coeff_cost_simple(int ci, tran_low_t abs_qc,
+                                        int coeff_ctx,
+                                        const LV_MAP_COEFF_COST *txb_costs,
+                                        int bwl, TX_TYPE tx_type,
+                                        const uint8_t *levels) {
+  // this simple version assume the coeff's scan_idx is not DC (scan_idx != 0)
+  // and not the last (scan_idx != eob - 1)
+  assert(ci > 0);
+  int cost = txb_costs->base_cost[coeff_ctx][AOMMIN(abs_qc, 3)];
+  if (abs_qc) {
+    cost += av1_cost_literal(1);
+    if (abs_qc > NUM_BASE_LEVELS) {
+      const int br_ctx = get_br_ctx(levels, ci, bwl, tx_type);
+      cost += get_br_cost(abs_qc, br_ctx, txb_costs->lps_cost[br_ctx]);
+      cost += get_golomb_cost(abs_qc);
+    }
+  }
+  return cost;
+}
+
+static INLINE int get_coeff_cost_general(int is_last, int ci, tran_low_t abs_qc,
+                                         int sgn, int coeff_ctx,
+                                         int dc_sign_ctx,
+                                         const LV_MAP_COEFF_COST *txb_costs,
+                                         int bwl, TX_TYPE tx_type,
+                                         const uint8_t *levels) {
+  int cost = 0;
+  if (is_last) {
+    cost += txb_costs->base_eob_cost[coeff_ctx][AOMMIN(abs_qc, 3) - 1];
+  } else {
+    cost += txb_costs->base_cost[coeff_ctx][AOMMIN(abs_qc, 3)];
+  }
+  if (abs_qc != 0) {
+    if (ci == 0) {
+      cost += txb_costs->dc_sign_cost[dc_sign_ctx][sgn];
+    } else {
+      cost += av1_cost_literal(1);
+    }
+    if (abs_qc > NUM_BASE_LEVELS) {
+      const int br_ctx = get_br_ctx(levels, ci, bwl, tx_type);
+      cost += get_br_cost(abs_qc, br_ctx, txb_costs->lps_cost[br_ctx]);
+      cost += get_golomb_cost(abs_qc);
+    }
+  }
+  return cost;
+}
+
+static INLINE void get_qc_dqc_low(tran_low_t abs_qc, int sgn, int dqv,
+                                  int shift, tran_low_t *qc_low,
+                                  tran_low_t *dqc_low) {
+  tran_low_t abs_qc_low = abs_qc - 1;
+  *qc_low = (-sgn ^ abs_qc_low) + sgn;
+  assert((sgn ? -abs_qc_low : abs_qc_low) == *qc_low);
+  tran_low_t abs_dqc_low = (abs_qc_low * dqv) >> shift;
+  *dqc_low = (-sgn ^ abs_dqc_low) + sgn;
+  assert((sgn ? -abs_dqc_low : abs_dqc_low) == *dqc_low);
+}
+
+static INLINE void update_coeff_general(
+    int *accu_rate, int64_t *accu_dist, int si, int eob, TX_SIZE tx_size,
+    TX_TYPE tx_type, int bwl, int height, int64_t rdmult, int shift,
+    int dc_sign_ctx, const int16_t *dequant, const int16_t *scan,
+    const LV_MAP_COEFF_COST *txb_costs, const tran_low_t *tcoeff,
+    tran_low_t *qcoeff, tran_low_t *dqcoeff, uint8_t *levels) {
+  const int dqv = dequant[si != 0];
+  const int ci = scan[si];
+  const tran_low_t qc = qcoeff[ci];
+  const int is_last = si == (eob - 1);
+  const int coeff_ctx = get_lower_levels_ctx_general(
+      is_last, si, bwl, height, levels, ci, tx_size, tx_type);
+  if (qc == 0) {
+    *accu_rate += txb_costs->base_cost[coeff_ctx][0];
+  } else {
+    const int sgn = (qc < 0) ? 1 : 0;
+    const tran_low_t abs_qc = abs(qc);
+    const tran_low_t tqc = tcoeff[ci];
+    const tran_low_t dqc = dqcoeff[ci];
+    const int64_t dist = get_coeff_dist(tqc, dqc, shift);
+    const int64_t dist0 = get_coeff_dist(tqc, 0, shift);
+    const int rate =
+        get_coeff_cost_general(is_last, ci, abs_qc, sgn, coeff_ctx, dc_sign_ctx,
+                               txb_costs, bwl, tx_type, levels);
+    const int64_t rd = RDCOST(rdmult, rate, dist);
+
+    tran_low_t qc_low, dqc_low;
+    get_qc_dqc_low(abs_qc, sgn, dqv, shift, &qc_low, &dqc_low);
+    const tran_low_t abs_qc_low = abs_qc - 1;
+    const int64_t dist_low = get_coeff_dist(tqc, dqc_low, shift);
+    const int rate_low =
+        get_coeff_cost_general(is_last, ci, abs_qc_low, sgn, coeff_ctx,
+                               dc_sign_ctx, txb_costs, bwl, tx_type, levels);
+    const int64_t rd_low = RDCOST(rdmult, rate_low, dist_low);
+    if (rd_low < rd) {
+      qcoeff[ci] = qc_low;
+      dqcoeff[ci] = dqc_low;
+      levels[get_padded_idx(ci, bwl)] = AOMMIN(abs_qc_low, INT8_MAX);
+      *accu_rate += rate_low;
+      *accu_dist += dist_low - dist0;
+    } else {
+      *accu_rate += rate;
+      *accu_dist += dist - dist0;
+    }
+  }
+}
+
+static INLINE void update_coeff_simple(
+    int *accu_rate, int64_t *accu_dist, int si, int eob, TX_SIZE tx_size,
+    TX_TYPE tx_type, int bwl, int64_t rdmult, int shift, const int16_t *dequant,
+    const int16_t *scan, const LV_MAP_COEFF_COST *txb_costs,
+    const tran_low_t *tcoeff, tran_low_t *qcoeff, tran_low_t *dqcoeff,
+    uint8_t *levels) {
+  const int dqv = dequant[1];
+  (void)eob;
+  // this simple version assume the coeff's scan_idx is not DC (scan_idx != 0)
+  // and not the last (scan_idx != eob - 1)
+  assert(si != eob - 1);
+  assert(si > 0);
+  const int ci = scan[si];
+  const tran_low_t qc = qcoeff[ci];
+  const int coeff_ctx = get_lower_levels_ctx(levels, ci, bwl, tx_size, tx_type);
+  if (qc == 0) {
+    *accu_rate += txb_costs->base_cost[coeff_ctx][0];
+  } else {
+    const tran_low_t abs_qc = abs(qc);
+    const tran_low_t tqc = tcoeff[ci];
+    const tran_low_t dqc = dqcoeff[ci];
+    const int64_t dist = get_coeff_dist(tqc, dqc, shift);
+    const int64_t dist0 = get_coeff_dist(tqc, 0, shift);
+    const int rate = get_coeff_cost_simple(ci, abs_qc, coeff_ctx, txb_costs,
+                                           bwl, tx_type, levels);
+    const int64_t rd = RDCOST(rdmult, rate, dist);
+
+    const int sgn = (qc < 0) ? 1 : 0;
+    tran_low_t qc_low, dqc_low;
+    get_qc_dqc_low(abs_qc, sgn, dqv, shift, &qc_low, &dqc_low);
+    const tran_low_t abs_qc_low = abs_qc - 1;
+    const int64_t dist_low = get_coeff_dist(tqc, dqc_low, shift);
+    const int rate_low = get_coeff_cost_simple(ci, abs_qc_low, coeff_ctx,
+                                               txb_costs, bwl, tx_type, levels);
+    const int64_t rd_low = RDCOST(rdmult, rate_low, dist_low);
+    if (rd_low < rd) {
+      qcoeff[ci] = qc_low;
+      dqcoeff[ci] = dqc_low;
+      levels[get_padded_idx(ci, bwl)] = AOMMIN(abs_qc_low, INT8_MAX);
+      *accu_rate += rate_low;
+      *accu_dist += dist_low - dist0;
+    } else {
+      *accu_rate += rate;
+      *accu_dist += dist - dist0;
+    }
+  }
+}
+
+static INLINE void update_coeff_eob(
+    int *accu_rate, int64_t *accu_dist, int *eob, int *nz_num, int *nz_ci,
+    int si, TX_SIZE tx_size, TX_TYPE tx_type, int bwl, int height,
+    int dc_sign_ctx, int64_t rdmult, int shift, const int16_t *dequant,
+    const int16_t *scan, const LV_MAP_EOB_COST *txb_eob_costs,
+    const LV_MAP_COEFF_COST *txb_costs, const tran_low_t *tcoeff,
+    tran_low_t *qcoeff, tran_low_t *dqcoeff, uint8_t *levels) {
+  const int dqv = dequant[si != 0];
+  assert(si != *eob - 1);
+  const int ci = scan[si];
+  const tran_low_t qc = qcoeff[ci];
+  const int coeff_ctx = get_lower_levels_ctx(levels, ci, bwl, tx_size, tx_type);
+  if (qc == 0) {
+    *accu_rate += txb_costs->base_cost[coeff_ctx][0];
+  } else {
+    int lower_level = 0;
+    const tran_low_t abs_qc = abs(qc);
+    const tran_low_t tqc = tcoeff[ci];
+    const tran_low_t dqc = dqcoeff[ci];
+    const int sgn = (qc < 0) ? 1 : 0;
+    const int64_t dist0 = get_coeff_dist(tqc, 0, shift);
+    int64_t dist = get_coeff_dist(tqc, dqc, shift) - dist0;
+    int rate =
+        get_coeff_cost_general(0, ci, abs_qc, sgn, coeff_ctx, dc_sign_ctx,
+                               txb_costs, bwl, tx_type, levels);
+    int64_t rd = RDCOST(rdmult, *accu_rate + rate, *accu_dist + dist);
+
+    tran_low_t qc_low, dqc_low;
+    get_qc_dqc_low(abs_qc, sgn, dqv, shift, &qc_low, &dqc_low);
+    const tran_low_t abs_qc_low = abs_qc - 1;
+    const int64_t dist_low = get_coeff_dist(tqc, dqc_low, shift) - dist0;
+    const int rate_low =
+        get_coeff_cost_general(0, ci, abs_qc_low, sgn, coeff_ctx, dc_sign_ctx,
+                               txb_costs, bwl, tx_type, levels);
+    const int64_t rd_low =
+        RDCOST(rdmult, *accu_rate + rate_low, *accu_dist + dist_low);
+
+    int lower_level_new_eob = 0;
+    const int new_eob = si + 1;
+    uint8_t tmp_levels[3];
+    for (int ni = 0; ni < *nz_num; ++ni) {
+      const int last_ci = nz_ci[ni];
+      tmp_levels[ni] = levels[get_padded_idx(last_ci, bwl)];
+      levels[get_padded_idx(last_ci, bwl)] = 0;
+    }
+
+    const int coeff_ctx_new_eob = get_lower_levels_ctx_general(
+        1, si, bwl, height, levels, ci, tx_size, tx_type);
+    const int new_eob_cost =
+        get_eob_cost(new_eob, txb_eob_costs, txb_costs, tx_type);
+    int rate_coeff_eob =
+        new_eob_cost + get_coeff_cost_general(1, ci, abs_qc, sgn,
+                                              coeff_ctx_new_eob, dc_sign_ctx,
+                                              txb_costs, bwl, tx_type, levels);
+    int64_t dist_new_eob = dist;
+    int64_t rd_new_eob = RDCOST(rdmult, rate_coeff_eob, dist_new_eob);
+
+    if (abs_qc_low > 0) {
+      const int rate_coeff_eob_low =
+          new_eob_cost +
+          get_coeff_cost_general(1, ci, abs_qc_low, sgn, coeff_ctx_new_eob,
+                                 dc_sign_ctx, txb_costs, bwl, tx_type, levels);
+      const int64_t dist_new_eob_low = dist_low;
+      const int64_t rd_new_eob_low =
+          RDCOST(rdmult, rate_coeff_eob_low, dist_new_eob_low);
+      if (rd_new_eob_low < rd_new_eob) {
+        lower_level_new_eob = 1;
+        rd_new_eob = rd_new_eob_low;
+        rate_coeff_eob = rate_coeff_eob_low;
+        dist_new_eob = dist_new_eob_low;
+      }
+    }
+
+    if (rd_low < rd) {
+      lower_level = 1;
+      rd = rd_low;
+      rate = rate_low;
+      dist = dist_low;
+    }
+
+    if (rd_new_eob < rd) {
+      for (int ni = 0; ni < *nz_num; ++ni) {
+        int last_ci = nz_ci[ni];
+        // levels[get_padded_idx(last_ci, bwl)] = 0;
+        qcoeff[last_ci] = 0;
+        dqcoeff[last_ci] = 0;
+      }
+      *eob = new_eob;
+      *nz_num = 0;
+      *accu_rate = rate_coeff_eob;
+      *accu_dist = dist_new_eob;
+      lower_level = lower_level_new_eob;
+    } else {
+      for (int ni = 0; ni < *nz_num; ++ni) {
+        const int last_ci = nz_ci[ni];
+        levels[get_padded_idx(last_ci, bwl)] = tmp_levels[ni];
+      }
+      *accu_rate += rate;
+      *accu_dist += dist;
+    }
+
+    if (lower_level) {
+      qcoeff[ci] = qc_low;
+      dqcoeff[ci] = dqc_low;
+      levels[get_padded_idx(ci, bwl)] = AOMMIN(abs_qc_low, INT8_MAX);
+    }
+    if (qcoeff[ci]) {
+      nz_ci[*nz_num] = ci;
+      ++*nz_num;
+    }
+  }
+}
+
+static INLINE void update_skip(int *accu_rate, int64_t *accu_dist, int *eob,
+                               int64_t rdmult, int skip_cost, int non_skip_cost,
+                               const int16_t *scan, tran_low_t *qcoeff,
+                               tran_low_t *dqcoeff) {
+  const int64_t rd = RDCOST(rdmult, *accu_rate + non_skip_cost, *accu_dist);
+  const int64_t rd_new_eob = RDCOST(rdmult, skip_cost, 0);
+  if (rd_new_eob < rd) {
+    for (int si = 0; si < *eob; ++si) {
+      const int ci = scan[si];
+      qcoeff[ci] = 0;
+      dqcoeff[ci] = 0;
+    }
+    *accu_dist = 0;
+    *accu_rate = 0;
+    *eob = 0;
+  }
+}
+
+int av1_optimize_txb_new(const struct AV1_COMP *cpi, MACROBLOCK *x, int plane,
+                         int blk_row, int blk_col, int block, TX_SIZE tx_size,
+                         TXB_CTX *txb_ctx, int *rate_cost) {
+  const AV1_COMMON *cm = &cpi->common;
+  MACROBLOCKD *xd = &x->e_mbd;
+  const PLANE_TYPE plane_type = get_plane_type(plane);
+  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
+  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, blk_row, blk_col,
+                                          tx_size, cm->reduced_tx_set_used);
+  const MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
+  const struct macroblock_plane *p = &x->plane[plane];
+  struct macroblockd_plane *pd = &xd->plane[plane];
+  tran_low_t *qcoeff = BLOCK_OFFSET(p->qcoeff, block);
+  tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
+  const tran_low_t *tcoeff = BLOCK_OFFSET(p->coeff, block);
+  const int16_t *dequant = p->dequant_QTX;
+  const int bwl = get_txb_bwl(tx_size);
+  const int width = get_txb_wide(tx_size);
+  const int height = get_txb_high(tx_size);
+  assert(width == (1 << bwl));
+  const int is_inter = is_inter_block(mbmi);
+  const SCAN_ORDER *scan_order = get_scan(tx_size, tx_type);
+  const int16_t *scan = scan_order->scan;
+  const LV_MAP_COEFF_COST *txb_costs = &x->coeff_costs[txs_ctx][plane_type];
+  const int eob_multi_size = txsize_log2_minus4[tx_size];
+  const LV_MAP_EOB_COST *txb_eob_costs =
+      &x->eob_costs[eob_multi_size][plane_type];
+
+  const int shift = av1_get_tx_scale(tx_size);
+  const int64_t rdmult =
+      ((x->rdmult * plane_rd_mult[is_inter][plane_type] << (2 * (xd->bd - 8))) +
+       2) >>
+      2;
+
+  uint8_t levels_buf[TX_PAD_2D];
+  uint8_t *const levels = set_levels(levels_buf, width);
+  av1_txb_init_levels(qcoeff, width, height, levels);
+
+  // TODO(angirbird): check iqmatrix
+
+  const int non_skip_cost = txb_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][0];
+  const int skip_cost = txb_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][1];
+  int eob = p->eobs[block];
+  const int eob_cost = get_eob_cost(eob, txb_eob_costs, txb_costs, tx_type);
+  int accu_rate = eob_cost;
+  int64_t accu_dist = 0;
+  int si = eob - 1;
+  const int ci = scan[si];
+  const tran_low_t qc = qcoeff[ci];
+  const tran_low_t abs_qc = abs(qc);
+  const int sgn = qc < 0;
+  const int max_nz_num = 2;
+  int nz_num = 1;
+  int nz_ci[3] = { ci, 0, 0 };
+  if (abs_qc >= 2) {
+    update_coeff_general(&accu_rate, &accu_dist, si, eob, tx_size, tx_type, bwl,
+                         height, rdmult, shift, txb_ctx->dc_sign_ctx, dequant,
+                         scan, txb_costs, tcoeff, qcoeff, dqcoeff, levels);
+    --si;
+  } else {
+    assert(abs_qc == 1);
+    const int coeff_ctx = get_lower_levels_ctx_general(
+        1, si, bwl, height, levels, ci, tx_size, tx_type);
+    accu_rate += get_coeff_cost_general(1, ci, abs_qc, sgn, coeff_ctx,
+                                        txb_ctx->dc_sign_ctx, txb_costs, bwl,
+                                        tx_type, levels);
+    const tran_low_t tqc = tcoeff[ci];
+    const tran_low_t dqc = dqcoeff[ci];
+    const int64_t dist = get_coeff_dist(tqc, dqc, shift);
+    const int64_t dist0 = get_coeff_dist(tqc, 0, shift);
+    accu_dist += dist - dist0;
+    --si;
+  }
+
+  for (; si >= 0 && nz_num <= max_nz_num; --si) {
+    update_coeff_eob(&accu_rate, &accu_dist, &eob, &nz_num, nz_ci, si, tx_size,
+                     tx_type, bwl, height, txb_ctx->dc_sign_ctx, rdmult, shift,
+                     dequant, scan, txb_eob_costs, txb_costs, tcoeff, qcoeff,
+                     dqcoeff, levels);
+  }
+
+  for (; si >= 1; --si) {
+    update_coeff_simple(&accu_rate, &accu_dist, si, eob, tx_size, tx_type, bwl,
+                        rdmult, shift, dequant, scan, txb_costs, tcoeff, qcoeff,
+                        dqcoeff, levels);
+  }
+
+  // DC position
+  if (si == 0) {
+    update_coeff_general(&accu_rate, &accu_dist, si, eob, tx_size, tx_type, bwl,
+                         height, rdmult, shift, txb_ctx->dc_sign_ctx, dequant,
+                         scan, txb_costs, tcoeff, qcoeff, dqcoeff, levels);
+  }
+
+  update_skip(&accu_rate, &accu_dist, &eob, rdmult, skip_cost, non_skip_cost,
+              scan, qcoeff, dqcoeff);
+
+  const int tx_type_cost = av1_tx_type_cost(cm, x, xd, plane, tx_size, tx_type);
+  if (eob == 0)
+    accu_rate += skip_cost;
+  else
+    accu_rate += non_skip_cost + tx_type_cost;
+
+  p->eobs[block] = eob;
+  p->txb_entropy_ctx[block] =
+      av1_get_txb_entropy_context(qcoeff, scan_order, p->eobs[block]);
+
+  *rate_cost = accu_rate;
+  return eob;
+}
+
+// This functio is deprecated, but we keep it here becasue hash trellis
+// is not integrated with av1_optimize_txb_new yet
 int av1_optimize_txb(const struct AV1_COMP *cpi, MACROBLOCK *x, int plane,
                      int blk_row, int blk_col, int block, TX_SIZE tx_size,
                      TXB_CTX *txb_ctx, int fast_mode, int *rate_cost) {
diff --git a/av1/encoder/encodetxb.h b/av1/encoder/encodetxb.h
index 6ad5fec..5d41d19 100644
--- a/av1/encoder/encodetxb.h
+++ b/av1/encoder/encodetxb.h
@@ -76,9 +76,9 @@
                           int mi_row, int mi_col);
 
 void hbt_destroy();
-int av1_optimize_txb(const AV1_COMP *cpi, MACROBLOCK *x, int plane, int blk_row,
-                     int blk_col, int block, TX_SIZE tx_size, TXB_CTX *txb_ctx,
-                     int fast_mode, int *rate_cost);
+int av1_optimize_txb_new(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+                         int blk_row, int blk_col, int block, TX_SIZE tx_size,
+                         TXB_CTX *txb_ctx, int *rate_cost);
 #ifdef __cplusplus
 }
 #endif
diff --git a/av1/encoder/rd.h b/av1/encoder/rd.h
index 47194d4..047f6e6 100644
--- a/av1/encoder/rd.h
+++ b/av1/encoder/rd.h
@@ -27,9 +27,9 @@
 #define RDDIV_BITS 7
 #define RD_EPB_SHIFT 6
 
-#define RDCOST(RM, R, D)                                          \
-  (ROUND_POWER_OF_TWO(((int64_t)R) * (RM), AV1_PROB_COST_SHIFT) + \
-   (D * (1 << RDDIV_BITS)))
+#define RDCOST(RM, R, D)                                            \
+  (ROUND_POWER_OF_TWO(((int64_t)(R)) * (RM), AV1_PROB_COST_SHIFT) + \
+   ((D) * (1 << RDDIV_BITS)))
 
 #define RDCOST_DBL(RM, R, D)                                       \
   (((((double)(R)) * (RM)) / (double)(1 << AV1_PROB_COST_SHIFT)) + \