Refactor get_br_cost.
Change-Id: Ibd14bdc8ec7184e2b9b830af925f643639995fdc
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 8bf1c96..157ac4f 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -284,19 +284,9 @@
return av1_cost_literal(1);
}
-static INLINE int get_br_cost(tran_low_t abs_qc, int ctx,
- const int *coeff_lps) {
- const tran_low_t min_level = 1 + NUM_BASE_LEVELS;
- const tran_low_t max_level = 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE;
- (void)ctx;
- if (abs_qc >= min_level) {
- if (abs_qc >= max_level) {
- return coeff_lps[COEFF_BASE_RANGE]; // COEFF_BASE_RANGE * cost0;
- } else {
- return coeff_lps[(abs_qc - min_level)]; // * cost0 + cost1;
- }
- }
- return 0;
+static INLINE int get_br_cost(tran_low_t level, const int *coeff_lps) {
+ const int base_range = AOMMIN(level - 1 - NUM_BASE_LEVELS, COEFF_BASE_RANGE);
+ return coeff_lps[base_range];
}
static INLINE int get_golomb_cost(int abs_qc) {
@@ -331,7 +321,7 @@
if (abs_qc > NUM_BASE_LEVELS) {
const int ctx =
get_br_ctx(txb_info->levels, pos, txb_info->bwl, tx_class);
- cost += get_br_cost(abs_qc, ctx, txb_costs->lps_cost[ctx]);
+ cost += get_br_cost(abs_qc, txb_costs->lps_cost[ctx]);
cost += get_golomb_cost(abs_qc);
}
}
@@ -756,10 +746,8 @@
if (v) {
// sign bit cost
if (level > NUM_BASE_LEVELS) {
- const int ctx = get_br_ctx(levels, pos, bwl, tx_class);
- const int base_range =
- AOMMIN(level - 1 - NUM_BASE_LEVELS, COEFF_BASE_RANGE);
- cost += lps_cost[ctx][base_range];
+ const int ctx = get_br_ctx_eob(pos, bwl, tx_class);
+ cost += get_br_cost(level, lps_cost[ctx]);
cost += get_golomb_cost(level);
}
if (c) {
@@ -784,9 +772,7 @@
cost += av1_cost_literal(1);
if (level > NUM_BASE_LEVELS) {
const int ctx = get_br_ctx(levels, pos, bwl, tx_class);
- const int base_range =
- AOMMIN(level - 1 - NUM_BASE_LEVELS, COEFF_BASE_RANGE);
- cost += lps_cost[ctx][base_range];
+ cost += get_br_cost(level, lps_cost[ctx]);
cost += get_golomb_cost(level);
}
}
@@ -807,9 +793,7 @@
cost += coeff_costs->dc_sign_cost[dc_sign_ctx][sign01];
if (level > NUM_BASE_LEVELS) {
const int ctx = get_br_ctx(levels, pos, bwl, tx_class);
- const int base_range =
- AOMMIN(level - 1 - NUM_BASE_LEVELS, COEFF_BASE_RANGE);
- cost += lps_cost[ctx][base_range];
+ cost += get_br_cost(level, lps_cost[ctx]);
cost += get_golomb_cost(level);
}
}
@@ -1294,7 +1278,29 @@
cost += av1_cost_literal(1);
if (abs_qc > NUM_BASE_LEVELS) {
const int br_ctx = get_br_ctx(levels, ci, bwl, tx_class);
- cost += get_br_cost(abs_qc, br_ctx, txb_costs->lps_cost[br_ctx]);
+ cost += get_br_cost(abs_qc, txb_costs->lps_cost[br_ctx]);
+ cost += get_golomb_cost(abs_qc);
+ }
+ }
+ return cost;
+}
+
+static INLINE int get_coeff_cost_eob(int ci, tran_low_t abs_qc, int sign,
+ int coeff_ctx, int dc_sign_ctx,
+ const LV_MAP_COEFF_COST *txb_costs,
+ int bwl, TX_CLASS tx_class) {
+ int cost = 0;
+ cost += txb_costs->base_eob_cost[coeff_ctx][AOMMIN(abs_qc, 3) - 1];
+ if (abs_qc != 0) {
+ if (ci == 0) {
+ cost += txb_costs->dc_sign_cost[dc_sign_ctx][sign];
+ } else {
+ cost += av1_cost_literal(1);
+ }
+ if (abs_qc > NUM_BASE_LEVELS) {
+ int br_ctx;
+ br_ctx = get_br_ctx_eob(ci, bwl, tx_class);
+ cost += get_br_cost(abs_qc, txb_costs->lps_cost[br_ctx]);
cost += get_golomb_cost(abs_qc);
}
}
@@ -1325,7 +1331,7 @@
br_ctx = get_br_ctx_eob(ci, bwl, tx_class);
else
br_ctx = get_br_ctx(levels, ci, bwl, tx_class);
- cost += get_br_cost(abs_qc, br_ctx, txb_costs->lps_cost[br_ctx]);
+ cost += get_br_cost(abs_qc, txb_costs->lps_cost[br_ctx]);
cost += get_golomb_cost(abs_qc);
}
}
@@ -1484,17 +1490,17 @@
const int new_eob_cost =
get_eob_cost(new_eob, txb_eob_costs, txb_costs, tx_class);
int rate_coeff_eob =
- new_eob_cost + get_coeff_cost_general(1, ci, abs_qc, sign,
- coeff_ctx_new_eob, dc_sign_ctx,
- txb_costs, bwl, tx_class, levels);
+ new_eob_cost + get_coeff_cost_eob(ci, abs_qc, sign, coeff_ctx_new_eob,
+ dc_sign_ctx, txb_costs, bwl,
+ tx_class);
int64_t dist_new_eob = dist;
int64_t rd_new_eob = RDCOST(rdmult, rate_coeff_eob, dist_new_eob);
if (abs_qc_low > 0) {
const int rate_coeff_eob_low =
- new_eob_cost +
- get_coeff_cost_general(1, ci, abs_qc_low, sign, coeff_ctx_new_eob,
- dc_sign_ctx, txb_costs, bwl, tx_class, levels);
+ new_eob_cost + get_coeff_cost_eob(ci, abs_qc_low, sign,
+ coeff_ctx_new_eob, dc_sign_ctx,
+ txb_costs, bwl, tx_class);
const int64_t dist_new_eob_low = dist_low;
const int64_t rd_new_eob_low =
RDCOST(rdmult, rate_coeff_eob_low, dist_new_eob_low);