[tcq] Enable 8-state command line options
[tcq] Organize tcq params
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 9f290ab..7fe6cc2 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -322,31 +322,41 @@
# trellis quant
if (aom_config("CONFIG_DQ") eq "yes") {
- add_proto qw/void av1_decide_states/, "const struct tcq_node_t *prev, const struct tcq_rate_t *rd, const struct prequant_t *pq, int limits, int tru_eob, int64_t rdmult, struct tcq_node_t *decision";
+ add_proto qw/void av1_decide_states/, "const struct tcq_node_t *prev, const struct tcq_rate_t *rd, const struct prequant_t *pq, int n_states, int limits, int tru_eob, int64_t rdmult, struct tcq_node_t *decision";
specialize qw/av1_decide_states avx2/;
+ add_proto qw/void av1_decide_states_st4/, "const struct tcq_node_t *prev, const struct tcq_rate_t *rd, const struct prequant_t *pq, int n_states, int limits, int tru_eob, int64_t rdmult, struct tcq_node_t *decision";
+ specialize qw/av1_decide_states_st4 avx2/;
add_proto qw/void av1_pre_quant/, "tran_low_t tqc, struct prequant_t* pqData, const int32_t* quant_ptr, int dqv, int log_scale, int scan_pos";
specialize qw/av1_pre_quant avx2/;
add_proto qw/void av1_calc_diag_ctx/, "int scan_hi, int scan_lo, int bwl, const uint8_t *prev_levels, const int16_t* scan, uint8_t *ctx";
specialize qw/av1_calc_diag_ctx avx2/;
- add_proto qw/void av1_get_rate_dist_def_luma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class, int diag_ctx, int eob_rate, struct tcq_rate_t *rd";
+
+ add_proto qw/void av1_get_rate_dist_def_luma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class, int diag_ctx, int eob_rate, int n_states, struct tcq_rate_t *rd";
specialize qw/av1_get_rate_dist_def_luma avx2/;
- add_proto qw/void av1_get_rate_dist_def_chroma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class, int diag_ctx, int eob_rate, int plane, int t_sign, int sign, struct tcq_rate_t *rd";
+ add_proto qw/void av1_get_rate_dist_def_luma_st4/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class, int diag_ctx, int eob_rate, int n_states, struct tcq_rate_t *rd";
+ specialize qw/av1_get_rate_dist_def_luma_st4 avx2/;
+ add_proto qw/void av1_get_rate_dist_def_chroma/, "const struct LV_MAP_COEFF_COST* txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class, int diag_ctx, int eob_rate, int plane, int t_sign, int sign, int n_states, struct tcq_rate_t *rd";
specialize qw/av1_get_rate_dist_def_chroma avx2/;
- add_proto qw/void av1_get_rate_dist_lf_luma/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx, int eob_rate, int dc_sign_ctx, int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int coeff_sign, struct tcq_rate_t *rd";
+ add_proto qw/void av1_get_rate_dist_lf_luma/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx, int eob_rate, int dc_sign_ctx, const int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int coeff_sign, int n_states, struct tcq_rate_t *rd";
specialize qw/av1_get_rate_dist_lf_luma avx2/;
- add_proto qw/void av1_get_rate_dist_lf_chroma/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx, int eob_rate, int dc_sign_ctx, int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int plane, int coeff_sign, struct tcq_rate_t *rd";
+ add_proto qw/void av1_get_rate_dist_lf_luma_st4/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx, int eob_rate, int dc_sign_ctx, const int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int coeff_sign, int n_states, struct tcq_rate_t *rd";
+ specialize qw/av1_get_rate_dist_lf_luma_st4 avx2/;
+ add_proto qw/void av1_get_rate_dist_lf_chroma/, "const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq, const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx, int eob_rate, int dc_sign_ctx, const int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int plane, int coeff_sign, int n_states, struct tcq_rate_t *rd";
specialize qw/av1_get_rate_dist_lf_chroma avx2/;
- add_proto qw/void av1_update_states/, "struct tcq_node_t *decision, int scan_idx, const struct tcq_ctx_t *cur_ctx, struct tcq_ctx_t *nxt_ctx";
- #specialize qw/av1_update_states avx2/;
+
+ add_proto qw/void av1_update_states/, "struct tcq_node_t *decision, int scan_idx, int n_states, const struct tcq_ctx_t *cur_ctx, struct tcq_ctx_t *nxt_ctx";
+ specialize qw/av1_update_states avx2/;
add_proto qw/void av1_init_lf_ctx/, "const uint8_t *lev, int scan_hi, int bwl, struct tcq_lf_ctx_t *lf_ctx";
specialize qw/av1_init_lf_ctx avx2/;
- add_proto qw/void av1_calc_lf_ctx/, "const struct tcq_lf_ctx_t *lf_ctx, int scan_pos, struct tcq_coeff_ctx_t *coeff_ctx";
- specialize qw/av1_calc_lf_ctx avx2/;
- add_proto qw/void av1_update_lf_ctx/, "const struct tcq_node_t *decision, struct tcq_lf_ctx_t *lf_ctx";
+ add_proto qw/void av1_calc_lf_ctx_st4/, "const struct tcq_lf_ctx_t *lf_ctx, int scan_pos, struct tcq_coeff_ctx_t *coeff_ctx";
+ specialize qw/av1_calc_lf_ctx_st4 avx2/;
+ add_proto qw/void av1_calc_lf_ctx_st8/, "const struct tcq_lf_ctx_t *lf_ctx, int scan_pos, struct tcq_coeff_ctx_t *coeff_ctx";
+ specialize qw/av1_calc_lf_ctx_st8 avx2/;
+ add_proto qw/void av1_update_lf_ctx/, "const struct tcq_node_t *decision, int n_states, struct tcq_lf_ctx_t *lf_ctx";
specialize qw/av1_update_lf_ctx avx2/;
add_proto qw/void av1_calc_block_eob_rate/, "struct macroblock *x, int plane, TX_SIZE tx_size, int eob, uint16_t *block_eob_rate";
specialize qw/av1_calc_block_eob_rate avx2/;
- add_proto qw/int av1_find_best_path/, "const struct tcq_node_t *trellis, const int16_t *scan, const int32_t *dequant, const qm_val_t *iqmatrix, const tran_low_t *tcoeff, int first_scan_pos, int log_scale, tran_low_t *qcoeff, tran_low_t *dqcoeff, int *min_rate, int64_t *min_cost";
+ add_proto qw/int av1_find_best_path/, "const struct tcq_node_t *trellis, int n_states_log2, const int16_t *scan, const int32_t *dequant, const qm_val_t *iqmatrix, const tran_low_t *tcoeff, int first_scan_pos, int log_scale, tran_low_t *qcoeff, tran_low_t *dqcoeff, int *min_rate, int64_t *min_cost";
specialize qw/av1_find_best_path avx2/;
}
diff --git a/av1/common/quant_common.h b/av1/common/quant_common.h
index f7613d9..d10ffe6 100644
--- a/av1/common/quant_common.h
+++ b/av1/common/quant_common.h
@@ -27,20 +27,14 @@
#define TCQ_HDR_FLAG 1 // Enable through header flag(s)
#define DQENABLE 0 // Determine whether to use DQ by dq_enable()
#define NEWQINDEX 1 // QP shift
-#define MORESTATES 0 // 1: 8-state; 0: 4-state
#define NEWHR 1 // 1:parity is determined by (base + LR)
#else
#define TCQ_HDR_FLAG 0
-#define DQENABLE 0 // Determine whether to use DQ by dq_enable()
-#define NEWQINDEX 0 // QP shift
-#define MORESTATES 0 // 1: 8-state; 0: 4-state
+#define DQENABLE 0 // Determine whether to use DQ by dq_enable()
+#define NEWQINDEX 0 // QP shift
#define NEWHR 0
#endif
-#if MORESTATES
-#define TOTALSTATES 8
-#else
-#define TOTALSTATES 4
-#endif
+#define TCQ_MAX_STATES 8
#define PHTHRESH 4
#define MINQ 0
diff --git a/av1/encoder/trellis_quant.c b/av1/encoder/trellis_quant.c
index 72cdcd9..0249ea7 100644
--- a/av1/encoder/trellis_quant.c
+++ b/av1/encoder/trellis_quant.c
@@ -24,6 +24,25 @@
#include "av1/encoder/rdopt.h"
#include "av1/encoder/tokenize.h"
+typedef void (*DecideStateFnc)(const struct tcq_node_t *prev,
+ const struct tcq_rate_t *rd,
+ const struct prequant_t *pq, int n_states,
+ int limits, int try_eob, int64_t rdmult,
+ struct tcq_node_t *decision);
+typedef void (*GetLfLumaRateDistFnc)(const struct LV_MAP_COEFF_COST *txb_costs,
+ const struct prequant_t *pq,
+ const struct tcq_coeff_ctx_t *coeff_ctx,
+ int blk_pos, int diag_ctx, int eob_rate,
+ int dc_sign_ctx, const int32_t *tmp_sign,
+ int bwl, TX_CLASS tx_class, int coeff_sign,
+ int n_states, struct tcq_rate_t *rd);
+typedef void (*GetDefLumaRateDistFnc)(const struct LV_MAP_COEFF_COST *txb_costs,
+ const struct prequant_t *pq,
+ const struct tcq_coeff_ctx_t *coeff_ctx,
+ int blk_pos, int bwl, TX_CLASS tx_class,
+ int diag_ctx, int eob_rate, int n_states,
+ struct tcq_rate_t *rd);
+
typedef struct {
uint8_t *base;
int bufsize;
@@ -46,19 +65,9 @@
return &lev->base[(2 * st + !lev->idx) * lev->bufsize];
}
-#if MORESTATES
-static const uint8_t next_st[TOTALSTATES][2] = { { 0, 4 }, { 4, 0 }, { 1, 5 },
- { 5, 1 }, { 6, 2 }, { 2, 6 },
- { 7, 3 }, { 3, 7 } };
-#else
-static const uint8_t next_st[TOTALSTATES][2] = {
- { 0, 2 }, { 2, 0 }, { 1, 3 }, { 3, 1 }
-};
-#endif
-
-static AOM_INLINE void init_tcq_decision(tcq_node_t *decision) {
+static AOM_INLINE void init_tcq_decision(tcq_node_t *decision, int n_states) {
static const tcq_node_t def = { INT64_MAX >> 10, INT32_MAX >> 1, -1, -2 };
- for (int state = 0; state < TOTALSTATES; state++) {
+ for (int state = 0; state < n_states; state++) {
memcpy(&decision[state], &def, sizeof(def));
}
}
@@ -376,7 +385,7 @@
static INLINE int get_coeff_cost_general(
int ci, tran_low_t abs_qc, int sign, int coeff_ctx, int mid_ctx,
int dc_sign_ctx, const LV_MAP_COEFF_COST *txb_costs, int bwl,
- TX_CLASS tx_class, int32_t *tmp_sign, int plane, int limits, int dq) {
+ TX_CLASS tx_class, const int32_t *tmp_sign, int plane, int limits, int dq) {
int cost = 0;
const int(*base_lf_cost_ptr)[DQ_CTXS][LF_BASE_SYMBOLS * 2] =
plane > 0 ? txb_costs->base_lf_cost_uv : txb_costs->base_lf_cost;
@@ -436,7 +445,7 @@
const LV_MAP_COEFF_COST *txb_costs,
int bwl, TX_CLASS tx_class,
#if CONFIG_CONTEXT_DERIVATION
- int32_t *tmp_sign,
+ const int32_t *tmp_sign,
#endif // CONFIG_CONTEXT_DERIVATION
int plane, int limits, int dq) {
int cost = 0;
@@ -762,16 +771,21 @@
void av1_decide_states_c(const struct tcq_node_t *prev,
const struct tcq_rate_t *rd,
- const struct prequant_t *pq, int limits, int try_eob,
- int64_t rdmult, struct tcq_node_t *decision) {
+ const struct prequant_t *pq, int n_states, int limits,
+ int try_eob, int64_t rdmult,
+ struct tcq_node_t *decision) {
const int32_t *rate = rd->rate;
const int32_t *rate_zero = rd->rate_zero;
const int32_t *rate_eob = rd->rate_eob;
- int64_t rdCost[2 * TOTALSTATES];
- int64_t rdCost_zero[TOTALSTATES];
+ int64_t rdCost[2 * TCQ_MAX_STATES];
+ int64_t rdCost_zero[TCQ_MAX_STATES];
int64_t rdCost_eob[2];
- for (int i = 0; i < TOTALSTATES; i++) {
+ // Init for ASAN
+ memset(rdCost, 0, sizeof(rdCost));
+ memset(rdCost_zero, 0, sizeof(rdCost_zero));
+
+ for (int i = 0; i < n_states; i++) {
int a0 = tcq_quant(i);
int a1 = a0 + 2;
int64_t dist0 = pq->deltaDist[a0];
@@ -782,54 +796,63 @@
}
rdCost_eob[0] = RDCOST(rdmult, rate_eob[0], pq->deltaDist[0]);
rdCost_eob[1] = RDCOST(rdmult, rate_eob[1], pq->deltaDist[2]);
-#if MORESTATES
- decide_new(rdCost[0], rdCost[1], rdCost_zero[0], rate[0], rate[1],
- rate_zero[0], pq->absLevel[0], pq->absLevel[2], limits,
- prev[0].rate, 0, &decision[0], &decision[4]);
- decide_new(rdCost[2], rdCost[3], rdCost_zero[1], rate[2], rate[3],
- rate_zero[1], pq->absLevel[0], pq->absLevel[2], limits,
- prev[1].rate, 1, &decision[4], &decision[0]);
- decide_new(rdCost[5], rdCost[4], rdCost_zero[2], rate[5], rate[4],
- rate_zero[2], pq->absLevel[3], pq->absLevel[1], limits,
- prev[2].rate, 2, &decision[1], &decision[5]);
- decide_new(rdCost[7], rdCost[6], rdCost_zero[3], rate[7], rate[6],
- rate_zero[3], pq->absLevel[3], pq->absLevel[1], limits,
- prev[3].rate, 3, &decision[5], &decision[1]);
- decide_new(rdCost[8], rdCost[9], rdCost_zero[4], rate[8], rate[9],
- rate_zero[4], pq->absLevel[0], pq->absLevel[2], limits,
- prev[4].rate, 4, &decision[6], &decision[2]);
- decide_new(rdCost[10], rdCost[11], rdCost_zero[5], rate[10], rate[11],
- rate_zero[5], pq->absLevel[0], pq->absLevel[2], limits,
- prev[5].rate, 5, &decision[2], &decision[6]);
- decide_new(rdCost[13], rdCost[12], rdCost_zero[6], rate[13], rate[12],
- rate_zero[6], pq->absLevel[3], pq->absLevel[1], limits,
- prev[6].rate, 6, &decision[7], &decision[3]);
- decide_new(rdCost[15], rdCost[14], rdCost_zero[7], rate[15], rate[14],
- rate_zero[7], pq->absLevel[3], pq->absLevel[1], limits,
- prev[7].rate, 7, &decision[3], &decision[7]);
-#else
- decide_new(rdCost[0], rdCost[1], rdCost_zero[0], rate[0], rate[1],
- rate_zero[0], pq->absLevel[0], pq->absLevel[2], limits,
- prev[0].rate, 0, &decision[0], &decision[2]);
- decide_new(rdCost[2], rdCost[3], rdCost_zero[1], rate[2], rate[3],
- rate_zero[1], pq->absLevel[0], pq->absLevel[2], limits,
- prev[1].rate, 1, &decision[2], &decision[0]);
- decide_new(rdCost[5], rdCost[4], rdCost_zero[2], rate[5], rate[4],
- rate_zero[2], pq->absLevel[3], pq->absLevel[1], limits,
- prev[2].rate, 2, &decision[1], &decision[3]);
- decide_new(rdCost[7], rdCost[6], rdCost_zero[3], rate[7], rate[6],
- rate_zero[3], pq->absLevel[3], pq->absLevel[1], limits,
- prev[3].rate, 3, &decision[3], &decision[1]);
-#endif
+ if (n_states == 4) {
+ decide_new(rdCost[0], rdCost[1], rdCost_zero[0], rate[0], rate[1],
+ rate_zero[0], pq->absLevel[0], pq->absLevel[2], limits,
+ prev[0].rate, 0, &decision[0], &decision[2]);
+ decide_new(rdCost[2], rdCost[3], rdCost_zero[1], rate[2], rate[3],
+ rate_zero[1], pq->absLevel[0], pq->absLevel[2], limits,
+ prev[1].rate, 1, &decision[2], &decision[0]);
+ decide_new(rdCost[5], rdCost[4], rdCost_zero[2], rate[5], rate[4],
+ rate_zero[2], pq->absLevel[3], pq->absLevel[1], limits,
+ prev[2].rate, 2, &decision[1], &decision[3]);
+ decide_new(rdCost[7], rdCost[6], rdCost_zero[3], rate[7], rate[6],
+ rate_zero[3], pq->absLevel[3], pq->absLevel[1], limits,
+ prev[3].rate, 3, &decision[3], &decision[1]);
+ } else { // n_states == 8
+ decide_new(rdCost[0], rdCost[1], rdCost_zero[0], rate[0], rate[1],
+ rate_zero[0], pq->absLevel[0], pq->absLevel[2], limits,
+ prev[0].rate, 0, &decision[0], &decision[4]);
+ decide_new(rdCost[2], rdCost[3], rdCost_zero[1], rate[2], rate[3],
+ rate_zero[1], pq->absLevel[0], pq->absLevel[2], limits,
+ prev[1].rate, 1, &decision[4], &decision[0]);
+ decide_new(rdCost[5], rdCost[4], rdCost_zero[2], rate[5], rate[4],
+ rate_zero[2], pq->absLevel[3], pq->absLevel[1], limits,
+ prev[2].rate, 2, &decision[1], &decision[5]);
+ decide_new(rdCost[7], rdCost[6], rdCost_zero[3], rate[7], rate[6],
+ rate_zero[3], pq->absLevel[3], pq->absLevel[1], limits,
+ prev[3].rate, 3, &decision[5], &decision[1]);
+ decide_new(rdCost[8], rdCost[9], rdCost_zero[4], rate[8], rate[9],
+ rate_zero[4], pq->absLevel[0], pq->absLevel[2], limits,
+ prev[4].rate, 4, &decision[6], &decision[2]);
+ decide_new(rdCost[10], rdCost[11], rdCost_zero[5], rate[10], rate[11],
+ rate_zero[5], pq->absLevel[0], pq->absLevel[2], limits,
+ prev[5].rate, 5, &decision[2], &decision[6]);
+ decide_new(rdCost[13], rdCost[12], rdCost_zero[6], rate[13], rate[12],
+ rate_zero[6], pq->absLevel[3], pq->absLevel[1], limits,
+ prev[6].rate, 6, &decision[7], &decision[3]);
+ decide_new(rdCost[15], rdCost[14], rdCost_zero[7], rate[15], rate[14],
+ rate_zero[7], pq->absLevel[3], pq->absLevel[1], limits,
+ prev[7].rate, 7, &decision[3], &decision[7]);
+ }
if (try_eob) {
- const int state0 = next_st[0][0];
- const int state1 = next_st[0][1];
+ const int state0 = 0;
+ const int state1 = n_states == 8 ? 4 : 2;
decide_eob(rdCost_eob[0], rdCost_eob[1], rate_eob[0], rate_eob[1],
pq->absLevel[0], pq->absLevel[2], &decision[state0],
&decision[state1]);
}
}
+void av1_decide_states_st4_c(const struct tcq_node_t *prev,
+ const struct tcq_rate_t *rd,
+ const struct prequant_t *pq, int n_states,
+ int limits, int try_eob, int64_t rdmult,
+ struct tcq_node_t *decision) {
+ av1_decide_states_c(prev, rd, pq, n_states, limits, try_eob, rdmult,
+ decision);
+}
+
void av1_pre_quant_c(tran_low_t tqc, struct prequant_t *pqData,
const int32_t *quant_ptr, int dqv, int log_scale,
int scan_pos) {
@@ -889,7 +912,7 @@
static int get_coeff_cost(int ci, tran_low_t abs_qc, int sign, int coeff_ctx,
int mid_ctx, int dc_sign_ctx,
const LV_MAP_COEFF_COST *txb_costs, int bwl,
- TX_CLASS tx_class, int32_t *tmp_sign, int plane,
+ TX_CLASS tx_class, const int32_t *tmp_sign, int plane,
int limits, int dq) {
return get_coeff_cost_general(ci, abs_qc, sign, coeff_ctx, mid_ctx,
dc_sign_ctx, txb_costs, bwl, tx_class,
@@ -899,28 +922,38 @@
plane, limits, dq);
}
-void trellis_first_pos(int scan_pos, int plane, TX_SIZE tx_size,
- TX_CLASS tx_class, int32_t *tmp_sign, int sharpness,
- tcq_levels_t *tcq_lev,
- tcq_node_t trellis[MAX_TRELLIS][TOTALSTATES],
- tran_low_t *qcoeff, const int64_t rdmult, int log_scale,
- const int16_t *scan, const tran_low_t *tcoeff,
- const int32_t *dequant, const int32_t *quant,
- const qm_val_t *iqmatrix, const uint16_t *block_eob_rate,
- const TXB_CTX *const txb_ctx,
- const LV_MAP_COEFF_COST *txb_costs) {
+void trellis_first_pos(const tcq_param_t *p, int scan_pos,
+ tcq_levels_t *tcq_lev, tcq_node_t *trellis) {
+ int n_states = p->n_states;
+ int n_states_log2 = p->n_states_log2;
+ int plane = p->plane;
+ TX_SIZE tx_size = p->tx_size;
+ TX_CLASS tx_class = p->tx_class;
+ int log_scale = p->log_scale;
+ int sharpness = p->sharpness;
+ int64_t rdmult = p->rdmult;
+ const int16_t *scan = p->scan;
+ const int32_t *tmp_sign = p->tmp_sign;
+ const tran_low_t *qcoeff = p->qcoeff;
+ const tran_low_t *tcoeff = p->tcoeff;
+ const int32_t *quant = p->quant;
+ const int32_t *dequant = p->dequant;
+ const qm_val_t *iqmatrix = p->iqmatrix;
+ const uint16_t *block_eob_rate = p->block_eob_rate;
+ const TXB_CTX *txb_ctx = p->txb_ctx;
+ const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
const int bwl = get_txb_bwl(tx_size);
const int height = get_txb_high(tx_size);
int blk_pos = scan[scan_pos];
- tcq_node_t *decision = trellis[scan_pos];
+ tcq_node_t *decision = &trellis[scan_pos << n_states_log2];
prequant_t pqData;
int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
av1_pre_quant(tcoeff[blk_pos], &pqData, quant, tempdqv, log_scale, scan_pos);
// init state
- init_tcq_decision(decision);
+ init_tcq_decision(decision, n_states);
const int row = blk_pos >> bwl;
const int col = blk_pos - (row << bwl);
@@ -952,8 +985,8 @@
,
plane) +
eob_rate;
- const int state0 = next_st[0][0];
- const int state1 = next_st[0][1];
+ const int state0 = 0;
+ const int state1 = (n_states == 4) ? 2 : 4;
decide(0, pqData.deltaDist[0], pqData.deltaDist[2], rdmult, rate_Q0_a,
rate_Q0_b, INT32_MAX >> 1, pqData.absLevel[0], pqData.absLevel[2],
limits, 0, -1, &decision[state0], &decision[state1]);
@@ -970,7 +1003,7 @@
const struct prequant_t *pq,
const struct tcq_coeff_ctx_t *coeff_ctx,
int blk_pos, int bwl, TX_CLASS tx_class,
- int diag_ctx, int eob_rate,
+ int diag_ctx, int eob_rate, int n_states,
struct tcq_rate_t *rd) {
const int plane = 0;
const int t_sign = 0;
@@ -978,7 +1011,7 @@
const int dc_sign_ctx = 0;
const tran_low_t *absLevel = pq->absLevel;
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int dq = tcq_quant(i);
int a0 = dq;
int a1 = a0 + 2;
@@ -1001,17 +1034,28 @@
bwl, tx_class, t_sign, plane);
}
+// Same as above function, but specialized to 4 states in the SIMD version.
+void av1_get_rate_dist_def_luma_st4_c(const struct LV_MAP_COEFF_COST *txb_costs,
+ const struct prequant_t *pq,
+ const struct tcq_coeff_ctx_t *coeff_ctx,
+ int blk_pos, int bwl, TX_CLASS tx_class,
+ int diag_ctx, int eob_rate, int n_states,
+ struct tcq_rate_t *rd) {
+ av1_get_rate_dist_def_luma_c(txb_costs, pq, coeff_ctx, blk_pos, bwl, tx_class,
+ diag_ctx, eob_rate, n_states, rd);
+}
+
void av1_get_rate_dist_def_chroma_c(const struct LV_MAP_COEFF_COST *txb_costs,
const struct prequant_t *pq,
const struct tcq_coeff_ctx_t *coeff_ctx,
int blk_pos, int bwl, TX_CLASS tx_class,
int diag_ctx, int eob_rate, int plane,
- int t_sign, int sign,
+ int t_sign, int sign, int n_states,
struct tcq_rate_t *rd) {
const tran_low_t *absLevel = pq->absLevel;
const int dc_sign_ctx = 0;
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int dq = tcq_quant(i);
int a0 = dq;
int a1 = a0 + 2;
@@ -1038,28 +1082,28 @@
const struct prequant_t *pq,
const struct tcq_coeff_ctx_t *coeff_ctx,
int blk_pos, int diag_ctx, int eob_rate,
- int dc_sign_ctx, int32_t *tmp_sign, int bwl,
- TX_CLASS tx_class, int coeff_sign,
- struct tcq_rate_t *rd) {
+ int dc_sign_ctx, const int32_t *tmp_sign,
+ int bwl, TX_CLASS tx_class, int coeff_sign,
+ int n_states, struct tcq_rate_t *rd) {
const tran_low_t *absLevel = pq->absLevel;
- uint8_t base_ctx[TOTALSTATES];
- uint8_t mid_ctx[TOTALSTATES];
+ uint8_t base_ctx;
+ uint8_t mid_ctx;
int t_sign = tmp_sign[blk_pos];
int plane = 0;
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int dq = tcq_quant(i);
int a0 = dq;
int a1 = a0 + 2;
- base_ctx[i] = (coeff_ctx->coef[i] & 15) + diag_ctx;
- mid_ctx[i] = coeff_ctx->coef[i] >> 4;
- int cost0 = get_coeff_cost(blk_pos, absLevel[a0], coeff_sign, base_ctx[i],
- mid_ctx[i], dc_sign_ctx, txb_costs, bwl,
- tx_class, tmp_sign, plane, 1, dq);
- int cost1 = get_coeff_cost(blk_pos, absLevel[a1], coeff_sign, base_ctx[i],
- mid_ctx[i], dc_sign_ctx, txb_costs, bwl,
- tx_class, tmp_sign, plane, 1, dq);
- rd->rate_zero[i] = txb_costs->base_lf_cost[base_ctx[i]][dq][0];
+ base_ctx = (coeff_ctx->coef[i] & 15) + diag_ctx;
+ mid_ctx = coeff_ctx->coef[i] >> 4;
+ int cost0 = get_coeff_cost(blk_pos, absLevel[a0], coeff_sign, base_ctx,
+ mid_ctx, dc_sign_ctx, txb_costs, bwl, tx_class,
+ tmp_sign, plane, 1, dq);
+ int cost1 = get_coeff_cost(blk_pos, absLevel[a1], coeff_sign, base_ctx,
+ mid_ctx, dc_sign_ctx, txb_costs, bwl, tx_class,
+ tmp_sign, plane, 1, dq);
+ rd->rate_zero[i] = txb_costs->base_lf_cost[base_ctx][dq][0];
rd->rate[2 * i] = cost0;
rd->rate[2 * i + 1] = cost1;
}
@@ -1073,31 +1117,45 @@
bwl, tx_class, t_sign, plane);
}
+// Same as above function, but specialized to 4 states in the SIMD version.
+void av1_get_rate_dist_lf_luma_st4_c(const struct LV_MAP_COEFF_COST *txb_costs,
+ const struct prequant_t *pq,
+ const struct tcq_coeff_ctx_t *coeff_ctx,
+ int blk_pos, int diag_ctx, int eob_rate,
+ int dc_sign_ctx, const int32_t *tmp_sign,
+ int bwl, TX_CLASS tx_class, int coeff_sign,
+ int n_states, struct tcq_rate_t *rd) {
+ av1_get_rate_dist_lf_luma_c(txb_costs, pq, coeff_ctx, blk_pos, diag_ctx,
+ eob_rate, dc_sign_ctx, tmp_sign, bwl, tx_class,
+ coeff_sign, n_states, rd);
+}
+
void av1_get_rate_dist_lf_chroma_c(const struct LV_MAP_COEFF_COST *txb_costs,
const struct prequant_t *pq,
const struct tcq_coeff_ctx_t *coeff_ctx,
int blk_pos, int diag_ctx, int eob_rate,
- int dc_sign_ctx, int32_t *tmp_sign, int bwl,
- TX_CLASS tx_class, int plane, int coeff_sign,
+ int dc_sign_ctx, const int32_t *tmp_sign,
+ int bwl, TX_CLASS tx_class, int plane,
+ int coeff_sign, int n_states,
struct tcq_rate_t *rd) {
const tran_low_t *absLevel = pq->absLevel;
- uint8_t base_ctx[TOTALSTATES];
- uint8_t mid_ctx[TOTALSTATES];
+ uint8_t base_ctx;
+ uint8_t mid_ctx;
int t_sign = tmp_sign[blk_pos];
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int dq = tcq_quant(i);
int a0 = dq;
int a1 = a0 + 2;
- base_ctx[i] = (coeff_ctx->coef[i] & 15) + diag_ctx;
- mid_ctx[i] = coeff_ctx->coef[i] >> 4;
- int cost0 = get_coeff_cost(blk_pos, absLevel[a0], coeff_sign, base_ctx[i],
- mid_ctx[i], dc_sign_ctx, txb_costs, bwl,
- tx_class, tmp_sign, plane, 1, dq);
- int cost1 = get_coeff_cost(blk_pos, absLevel[a1], coeff_sign, base_ctx[i],
- mid_ctx[i], dc_sign_ctx, txb_costs, bwl,
- tx_class, tmp_sign, plane, 1, dq);
- rd->rate_zero[i] = txb_costs->base_lf_cost_uv[base_ctx[i]][dq][0];
+ base_ctx = (coeff_ctx->coef[i] & 15) + diag_ctx;
+ mid_ctx = coeff_ctx->coef[i] >> 4;
+ int cost0 = get_coeff_cost(blk_pos, absLevel[a0], coeff_sign, base_ctx,
+ mid_ctx, dc_sign_ctx, txb_costs, bwl, tx_class,
+ tmp_sign, plane, 1, dq);
+ int cost1 = get_coeff_cost(blk_pos, absLevel[a1], coeff_sign, base_ctx,
+ mid_ctx, dc_sign_ctx, txb_costs, bwl, tx_class,
+ tmp_sign, plane, 1, dq);
+ rd->rate_zero[i] = txb_costs->base_lf_cost_uv[base_ctx][dq][0];
rd->rate[2 * i] = cost0;
rd->rate[2 * i + 1] = cost1;
}
@@ -1124,10 +1182,10 @@
}
}
-void av1_update_states_c(tcq_node_t *decision, int scan_idx,
+void av1_update_states_c(tcq_node_t *decision, int scan_idx, int n_states,
const struct tcq_ctx_t *cur_ctx,
struct tcq_ctx_t *nxt_ctx) {
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int prevId = decision[i].prevId;
int absLevel = decision[i].absLevel;
if (prevId >= 0) {
@@ -1143,8 +1201,9 @@
static void update_levels_diagonal(tcq_levels_t *tcq_lev, const int16_t *scan,
int bufsize, int bwl, int scan_hi,
- int scan_lo, const tcq_ctx_t *tcq_ctx) {
- for (int i = 0; i < TOTALSTATES; i++) {
+ int scan_lo, int n_states,
+ const tcq_ctx_t *tcq_ctx) {
+ for (int i = 0; i < n_states; i++) {
int orig_id = tcq_ctx[i].orig_id;
uint8_t *cur_lev = tcq_levels_cur(tcq_lev, i);
uint8_t *prev_lev = tcq_levels_prev(tcq_lev, orig_id);
@@ -1161,35 +1220,46 @@
}
// Handle trellis default region for Luma, TX_CLASS_2D blocks.
-void trellis_loop_diagonal(
- int scan_hi, int scan_lo, int plane, TX_SIZE tx_size, TX_CLASS tx_class,
- int32_t *tmp_sign, int sharpness, tcq_levels_t *tcq_lev,
- tcq_ctx_t tcq_ctx[TOTALSTATES],
- tcq_node_t trellis[MAX_TRELLIS][TOTALSTATES], tran_low_t *qcoeff,
- const int64_t rdmult, int log_scale, const int16_t *scan,
- const tran_low_t *tcoeff, const int32_t *dequant, const int32_t *quant,
- const qm_val_t *iqmatrix, const uint16_t *block_eob_rate,
- const TXB_CTX *const txb_ctx, const LV_MAP_COEFF_COST *txb_costs) {
+// TCQ 4-state
+static void trellis_loop_diagonal_st4(const tcq_param_t *p, int scan_hi,
+ int scan_lo, tcq_levels_t *tcq_lev,
+ tcq_ctx_t tcq_ctx[2 * TCQ_MAX_STATES],
+ tcq_node_t *trellis) {
+ int plane = p->plane;
+ TX_SIZE tx_size = p->tx_size;
+ TX_CLASS tx_class = p->tx_class;
+ int log_scale = p->log_scale;
+ int try_eob = p->sharpness == 0;
+ int64_t rdmult = p->rdmult;
+ const int16_t *scan = p->scan;
+ const tran_low_t *tcoeff = p->tcoeff;
+ const int32_t *quant = p->quant;
+ const int32_t *dequant = p->dequant;
+ const qm_val_t *iqmatrix = p->iqmatrix;
+ const uint16_t *block_eob_rate = p->block_eob_rate;
+ const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
const int bwl = get_txb_bwl(tx_size);
const int height = get_txb_high(tx_size);
const int pos0 = scan[scan_hi];
const int diag_ctx = get_nz_map_ctx_from_stats(0, pos0, bwl, TX_CLASS_2D, 0);
+
+ assert(p->n_states == 4);
assert(plane == 0);
assert(tx_class == TX_CLASS_2D);
(void)plane;
- (void)tmp_sign;
- (void)qcoeff;
- (void)txb_ctx;
+
+ const int n_st = 4;
+ const int n_st_log2 = 2;
// Precompute base and mid ctx values, as they are independent across
// the diagonal pass.
tcq_levels_swap(tcq_lev);
int i_ctx = scan_hi & 1;
- tcq_ctx_t *cur_ctx = &tcq_ctx[i_ctx ? TOTALSTATES : 0];
- tcq_ctx_t *nxt_ctx = &tcq_ctx[i_ctx ? 0 : TOTALSTATES];
+ tcq_ctx_t *cur_ctx = &tcq_ctx[i_ctx ? TCQ_MAX_STATES : 0];
+ tcq_ctx_t *nxt_ctx = &tcq_ctx[i_ctx ? 0 : TCQ_MAX_STATES];
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_st; i++) {
uint8_t *prev_levels = tcq_levels_prev(tcq_lev, i);
av1_calc_diag_ctx(scan_hi, scan_lo, bwl, prev_levels, scan, cur_ctx[i].ctx);
cur_ctx[i].orig_id = i;
@@ -1197,8 +1267,8 @@
for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
const int blk_pos = scan[scan_pos];
- tcq_node_t *decision = trellis[scan_pos];
- tcq_node_t *prev_decision = trellis[scan_pos + 1];
+ tcq_node_t *decision = &trellis[scan_pos << n_st_log2];
+ tcq_node_t *prev_decision = &decision[n_st];
prequant_t pqData;
int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
@@ -1206,12 +1276,12 @@
scan_pos);
// init state
- init_tcq_decision(decision);
+ init_tcq_decision(decision, n_st);
const int limits = 0;
// calculate rate distortion
tcq_coeff_ctx_t coeff_ctx;
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_st; i++) {
coeff_ctx.coef[i] = cur_ctx[i].ctx[scan_pos - scan_lo];
}
int eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
@@ -1219,14 +1289,13 @@
int eob_rate = block_eob_rate[scan_pos];
tcq_rate_t rd;
- av1_get_rate_dist_def_luma(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
- tx_class, diag_ctx, eob_rate, &rd);
+ av1_get_rate_dist_def_luma_st4(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
+ tx_class, diag_ctx, eob_rate, n_st, &rd);
- int try_eob = sharpness == 0;
- av1_decide_states(prev_decision, &rd, &pqData, limits, try_eob, rdmult,
- decision);
+ av1_decide_states_st4(prev_decision, &rd, &pqData, n_st, limits, try_eob,
+ rdmult, decision);
- av1_update_states(decision, scan_pos - scan_lo, cur_ctx, nxt_ctx);
+ av1_update_states(decision, scan_pos - scan_lo, n_st, cur_ctx, nxt_ctx);
// Swap cur/nxt context.
tcq_ctx_t *tmp = cur_ctx;
@@ -1235,7 +1304,94 @@
}
update_levels_diagonal(tcq_lev, scan, tcq_lev->bufsize, bwl, scan_hi, scan_lo,
- cur_ctx);
+ n_st, cur_ctx);
+}
+
+// TCQ 8-state
+static void trellis_loop_diagonal_st8(const tcq_param_t *p, int scan_hi,
+ int scan_lo, tcq_levels_t *tcq_lev,
+ tcq_ctx_t tcq_ctx[2 * TCQ_MAX_STATES],
+ tcq_node_t *trellis) {
+ int plane = p->plane;
+ TX_SIZE tx_size = p->tx_size;
+ TX_CLASS tx_class = p->tx_class;
+ int log_scale = p->log_scale;
+ int try_eob = p->sharpness == 0;
+ int64_t rdmult = p->rdmult;
+ const int16_t *scan = p->scan;
+ const tran_low_t *tcoeff = p->tcoeff;
+ const int32_t *quant = p->quant;
+ const int32_t *dequant = p->dequant;
+ const qm_val_t *iqmatrix = p->iqmatrix;
+ const uint16_t *block_eob_rate = p->block_eob_rate;
+ const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
+ const int bwl = get_txb_bwl(tx_size);
+ const int height = get_txb_high(tx_size);
+ const int pos0 = scan[scan_hi];
+ const int diag_ctx = get_nz_map_ctx_from_stats(0, pos0, bwl, TX_CLASS_2D, 0);
+
+ assert(p->n_states == 8);
+ assert(plane == 0);
+ assert(tx_class == TX_CLASS_2D);
+ (void)plane;
+
+ const int n_st = 8;
+ const int n_st_log2 = 3;
+
+ // Precompute base and mid ctx values, as they are independent across
+ // the diagonal pass.
+ tcq_levels_swap(tcq_lev);
+
+ int i_ctx = scan_hi & 1;
+ tcq_ctx_t *cur_ctx = &tcq_ctx[i_ctx ? TCQ_MAX_STATES : 0];
+ tcq_ctx_t *nxt_ctx = &tcq_ctx[i_ctx ? 0 : TCQ_MAX_STATES];
+
+ for (int i = 0; i < n_st; i++) {
+ uint8_t *prev_levels = tcq_levels_prev(tcq_lev, i);
+ av1_calc_diag_ctx(scan_hi, scan_lo, bwl, prev_levels, scan, cur_ctx[i].ctx);
+ cur_ctx[i].orig_id = i;
+ }
+
+ for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
+ const int blk_pos = scan[scan_pos];
+ tcq_node_t *decision = &trellis[scan_pos << n_st_log2];
+ tcq_node_t *prev_decision = &decision[n_st];
+
+ prequant_t pqData;
+ int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
+ av1_pre_quant(tcoeff[blk_pos], &pqData, quant, tempdqv, log_scale,
+ scan_pos);
+
+ // init state
+ init_tcq_decision(decision, n_st);
+ const int limits = 0;
+
+ // calculate rate distortion
+ tcq_coeff_ctx_t coeff_ctx;
+ for (int i = 0; i < n_st; i++) {
+ coeff_ctx.coef[i] = cur_ctx[i].ctx[scan_pos - scan_lo];
+ }
+ int eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
+ coeff_ctx.coef_eob = eob_ctx;
+ int eob_rate = block_eob_rate[scan_pos];
+
+ tcq_rate_t rd;
+ av1_get_rate_dist_def_luma(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
+ tx_class, diag_ctx, eob_rate, n_st, &rd);
+
+ av1_decide_states(prev_decision, &rd, &pqData, n_st, limits, try_eob,
+ rdmult, decision);
+
+ av1_update_states(decision, scan_pos - scan_lo, n_st, cur_ctx, nxt_ctx);
+
+ // Swap cur/nxt context.
+ tcq_ctx_t *tmp = cur_ctx;
+ cur_ctx = nxt_ctx;
+ nxt_ctx = tmp;
+ }
+
+ update_levels_diagonal(tcq_lev, scan, tcq_lev->bufsize, bwl, scan_hi, scan_lo,
+ n_st, cur_ctx);
}
void av1_init_lf_ctx_c(const uint8_t *lev, int scan_hi, int bwl,
@@ -1261,8 +1417,8 @@
// Initialize LF neighbor context.
// The lf_ctx->last[] array tracks the last N previous coeffs (LIFO),
// and used to calculate coeff neighbor contexts.
-void av1_calc_lf_ctx_c(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
- struct tcq_coeff_ctx_t *coeff_ctx) {
+void av1_calc_lf_ctx_st4_c(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
+ struct tcq_coeff_ctx_t *coeff_ctx) {
static const int8_t kMaxCtx[16] = { 8, 6, 6, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4 };
static const int8_t kScanDiag[MAX_LF_SCAN] = { 0, 1, 1, 2, 2, 2, 3, 3, 3, 3 };
@@ -1272,8 +1428,9 @@
{ 0, 0, 3, 3, 0, 0, 1, 3, 1, 0, 0 }, // diag 2
{ 0, 0, 0, 3, 3, 0, 0, 0, 1, 3, 1 }, // diag 3
};
+ int n_states = 4;
- for (int st = 0; st < TOTALSTATES; st++) {
+ for (int st = 0; st < n_states; st++) {
int diag = kScanDiag[scan_pos];
int base = 0;
int mid = 0;
@@ -1292,12 +1449,44 @@
}
}
-void av1_update_lf_ctx_c(const struct tcq_node_t *decision,
- struct tcq_lf_ctx_t *lf_ctx) {
- tcq_lf_ctx_t save[TOTALSTATES];
- memcpy(save, lf_ctx, sizeof(tcq_lf_ctx_t) * TOTALSTATES);
+void av1_calc_lf_ctx_st8_c(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
+ struct tcq_coeff_ctx_t *coeff_ctx) {
+ static const int8_t kMaxCtx[16] = { 8, 6, 6, 4, 4, 4, 4, 4,
+ 4, 4, 4, 4, 4, 4, 4, 4 };
+ static const int8_t kScanDiag[MAX_LF_SCAN] = { 0, 1, 1, 2, 2, 2, 3, 3, 3, 3 };
+ static const int8_t kNbrMask[4][11] = {
+ { 3, 3, 1, 3, 1, 0, 0, 0, 0, 0, 0 }, // diag 0
+ { 0, 3, 3, 0, 1, 3, 1, 0, 0, 0, 0 }, // diag 1
+ { 0, 0, 3, 3, 0, 0, 1, 3, 1, 0, 0 }, // diag 2
+ { 0, 0, 0, 3, 3, 0, 0, 0, 1, 3, 1 }, // diag 3
+ };
+ int n_states = 8;
- for (int st = 0; st < TOTALSTATES; st++) {
+ for (int st = 0; st < n_states; st++) {
+ int diag = kScanDiag[scan_pos];
+ int base = 0;
+ int mid = 0;
+ for (int i = 0; i < 11; i++) {
+ int mask = kNbrMask[diag][i];
+ if (mask) {
+ base += AOMMIN(lf_ctx[st].last[i], 5);
+ if (mask >> 1) {
+ mid += AOMMIN(lf_ctx[st].last[i], MAX_VAL_BR_CTX);
+ }
+ }
+ }
+ int base_ctx = AOMMIN((base + 1) >> 1, kMaxCtx[scan_pos]);
+ int mid_ctx = AOMMIN((mid + 1) >> 1, 6) + ((scan_pos == 0) ? 0 : 7);
+ coeff_ctx->coef[st] = base_ctx + (mid_ctx << 4);
+ }
+}
+
+void av1_update_lf_ctx_c(const struct tcq_node_t *decision, int n_states,
+ struct tcq_lf_ctx_t *lf_ctx) {
+ tcq_lf_ctx_t save[TCQ_MAX_STATES];
+ memcpy(save, lf_ctx, sizeof(tcq_lf_ctx_t) * n_states);
+
+ for (int st = 0; st < n_states; st++) {
int absLevel = decision[st].absLevel;
int prevId = decision[st].prevId;
int new_eob = prevId < 0;
@@ -1313,25 +1502,33 @@
}
// Handle trellis Low-freq (LF) region for Luma, TX_CLASS_2D blocks.
-void trellis_loop_lf(int scan_hi, int scan_lo, int plane, TX_SIZE tx_size,
- TX_CLASS tx_class, int32_t *tmp_sign, int sharpness,
- tcq_levels_t *tcq_lev,
- tcq_node_t trellis[MAX_TRELLIS][TOTALSTATES],
- tran_low_t *qcoeff, const int64_t rdmult, int log_scale,
- const int16_t *scan, const tran_low_t *tcoeff,
- const int32_t *dequant, const int32_t *quant,
- const qm_val_t *iqmatrix, const uint16_t *block_eob_rate,
- const TXB_CTX *const txb_ctx,
- const LV_MAP_COEFF_COST *txb_costs) {
+// TCQ 4-state
+static void trellis_loop_lf_st4(const tcq_param_t *p, int scan_hi, int scan_lo,
+ tcq_levels_t *tcq_lev, tcq_node_t *trellis) {
+ TX_SIZE tx_size = p->tx_size;
+ TX_CLASS tx_class = p->tx_class;
+ int log_scale = p->log_scale;
+ int try_eob = p->sharpness == 0;
+ int64_t rdmult = p->rdmult;
+ const int16_t *scan = p->scan;
+ const int32_t *tmp_sign = p->tmp_sign;
+ const tran_low_t *tcoeff = p->tcoeff;
+ const int32_t *quant = p->quant;
+ const int32_t *dequant = p->dequant;
+ const qm_val_t *iqmatrix = p->iqmatrix;
+ const uint16_t *block_eob_rate = p->block_eob_rate;
+ const TXB_CTX *txb_ctx = p->txb_ctx;
+ const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
const int bwl = get_txb_bwl(tx_size);
const int height = get_txb_high(tx_size);
- assert(plane == 0);
+ assert(p->plane == 0);
assert(tx_class == TX_CLASS_2D);
- (void)plane;
- (void)qcoeff;
- tcq_lf_ctx_t lf_ctx[TOTALSTATES];
- for (int i = 0; i < TOTALSTATES; i++) {
+ const int n_st = 4;
+ const int n_st_log2 = 2;
+
+ tcq_lf_ctx_t lf_ctx[TCQ_MAX_STATES];
+ for (int i = 0; i < n_st; i++) {
uint8_t *lev = tcq_levels_cur(tcq_lev, i);
av1_init_lf_ctx(lev, scan_hi, bwl, &lf_ctx[i]);
}
@@ -1339,8 +1536,8 @@
for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
int blk_pos = scan[scan_pos];
- tcq_node_t *decision = trellis[scan_pos];
- tcq_node_t *prd = trellis[scan_pos + 1];
+ tcq_node_t *decision = &trellis[scan_pos << n_st_log2];
+ tcq_node_t *prd = &decision[n_st];
prequant_t pqData;
int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
@@ -1348,50 +1545,134 @@
scan_pos);
// init state
- init_tcq_decision(decision);
+ init_tcq_decision(decision, n_st);
const int coeff_sign = tcoeff[blk_pos] < 0;
const int limits = 1; // Always in LF region.
// calculate contexts
tcq_coeff_ctx_t coeff_ctx;
int diag_ctx = get_nz_map_ctx_from_stats_lf(0, blk_pos, bwl, tx_class);
- av1_calc_lf_ctx(lf_ctx, scan_pos, &coeff_ctx);
int eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
int eob_rate = block_eob_rate[scan_pos];
coeff_ctx.coef_eob = eob_ctx;
+ av1_calc_lf_ctx_st4(lf_ctx, scan_pos, &coeff_ctx);
+
+ // calculate rate distortion
+ tcq_rate_t rd;
+ av1_get_rate_dist_lf_luma_st4(
+ txb_costs, &pqData, &coeff_ctx, blk_pos, diag_ctx, eob_rate,
+ txb_ctx->dc_sign_ctx, tmp_sign, bwl, tx_class, coeff_sign, n_st, &rd);
+
+ av1_decide_states_st4(prd, &rd, &pqData, n_st, limits, try_eob, rdmult,
+ decision);
+
+ av1_update_lf_ctx(decision, n_st, lf_ctx);
+ }
+}
+
+// TCQ 8-state
+static void trellis_loop_lf_st8(const tcq_param_t *p, int scan_hi, int scan_lo,
+ tcq_levels_t *tcq_lev, tcq_node_t *trellis) {
+ TX_SIZE tx_size = p->tx_size;
+ TX_CLASS tx_class = p->tx_class;
+ int log_scale = p->log_scale;
+ int try_eob = p->sharpness == 0;
+ int64_t rdmult = p->rdmult;
+ const int16_t *scan = p->scan;
+ const int32_t *tmp_sign = p->tmp_sign;
+ const tran_low_t *tcoeff = p->tcoeff;
+ const int32_t *quant = p->quant;
+ const int32_t *dequant = p->dequant;
+ const qm_val_t *iqmatrix = p->iqmatrix;
+ const uint16_t *block_eob_rate = p->block_eob_rate;
+ const TXB_CTX *txb_ctx = p->txb_ctx;
+ const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
+ const int bwl = get_txb_bwl(tx_size);
+ const int height = get_txb_high(tx_size);
+ assert(p->plane == 0);
+ assert(tx_class == TX_CLASS_2D);
+
+ const int n_st = 8;
+ const int n_st_log2 = 3;
+
+ tcq_lf_ctx_t lf_ctx[TCQ_MAX_STATES];
+ for (int i = 0; i < n_st; i++) {
+ uint8_t *lev = tcq_levels_cur(tcq_lev, i);
+ av1_init_lf_ctx(lev, scan_hi, bwl, &lf_ctx[i]);
+ }
+
+ for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
+ int blk_pos = scan[scan_pos];
+
+ tcq_node_t *decision = &trellis[scan_pos << n_st_log2];
+ tcq_node_t *prd = &decision[n_st];
+
+ prequant_t pqData;
+ int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
+ av1_pre_quant(tcoeff[blk_pos], &pqData, quant, tempdqv, log_scale,
+ scan_pos);
+
+ // init state
+ init_tcq_decision(decision, n_st);
+ const int coeff_sign = tcoeff[blk_pos] < 0;
+ const int limits = 1; // Always in LF region.
+
+ // calculate contexts
+ tcq_coeff_ctx_t coeff_ctx;
+ int diag_ctx = get_nz_map_ctx_from_stats_lf(0, blk_pos, bwl, tx_class);
+ int eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
+ int eob_rate = block_eob_rate[scan_pos];
+ coeff_ctx.coef_eob = eob_ctx;
+ av1_calc_lf_ctx_st8(lf_ctx, scan_pos, &coeff_ctx);
// calculate rate distortion
tcq_rate_t rd;
av1_get_rate_dist_lf_luma(txb_costs, &pqData, &coeff_ctx, blk_pos, diag_ctx,
eob_rate, txb_ctx->dc_sign_ctx, tmp_sign, bwl,
- tx_class, coeff_sign, &rd);
+ tx_class, coeff_sign, n_st, &rd);
- int try_eob = sharpness == 0;
- av1_decide_states(prd, &rd, &pqData, limits, try_eob, rdmult, decision);
+ av1_decide_states(prd, &rd, &pqData, n_st, limits, try_eob, rdmult,
+ decision);
- av1_update_lf_ctx(decision, lf_ctx);
+ av1_update_lf_ctx(decision, n_st, lf_ctx);
}
}
-void trellis_loop(int first_scan_pos, int scan_hi, int scan_lo, int plane,
- TX_SIZE tx_size, TX_CLASS tx_class, int32_t *tmp_sign,
- int sharpness, tcq_levels_t *tcq_lev,
- tcq_node_t trellis[MAX_TRELLIS][TOTALSTATES],
- tran_low_t *qcoeff, const int64_t rdmult, int log_scale,
- const int16_t *scan, const tran_low_t *tcoeff,
- const int32_t *dequant, const int32_t *quant,
- const qm_val_t *iqmatrix, const uint16_t *block_eob_rate,
- const TXB_CTX *const txb_ctx,
- const LV_MAP_COEFF_COST *txb_costs) {
+void trellis_loop(const tcq_param_t *p, int first_scan_pos, int scan_hi,
+ int scan_lo, tcq_levels_t *tcq_lev, tcq_node_t *trellis) {
+ int n_states = p->n_states;
+ int n_states_log2 = p->n_states_log2;
+ int plane = p->plane;
+ TX_SIZE tx_size = p->tx_size;
+ TX_CLASS tx_class = p->tx_class;
+ int log_scale = p->log_scale;
+ int sharpness = p->sharpness;
+ int try_eob = sharpness == 0;
+ int64_t rdmult = p->rdmult;
+ const int16_t *scan = p->scan;
+ const int32_t *tmp_sign = p->tmp_sign;
+ const tran_low_t *tcoeff = p->tcoeff;
+ const int32_t *quant = p->quant;
+ const int32_t *dequant = p->dequant;
+ const qm_val_t *iqmatrix = p->iqmatrix;
+ const uint16_t *block_eob_rate = p->block_eob_rate;
+ const TXB_CTX *txb_ctx = p->txb_ctx;
+ const LV_MAP_COEFF_COST *txb_costs = p->txb_costs;
const int bwl = get_txb_bwl(tx_size);
const int height = get_txb_high(tx_size);
- (void)qcoeff;
+ DecideStateFnc f_decide_states =
+ n_states == 4 ? av1_decide_states_st4 : av1_decide_states;
+ GetDefLumaRateDistFnc f_get_rate_dist_def_luma =
+ n_states == 4 ? av1_get_rate_dist_def_luma_st4
+ : av1_get_rate_dist_def_luma;
+ GetLfLumaRateDistFnc f_get_rate_dist_lf_luma =
+ n_states == 4 ? av1_get_rate_dist_lf_luma_st4 : av1_get_rate_dist_lf_luma;
for (int scan_pos = scan_hi; scan_pos >= scan_lo; scan_pos--) {
tcq_levels_swap(tcq_lev);
- uint8_t *levels[TOTALSTATES];
- uint8_t *prev_levels[TOTALSTATES];
- for (int i = 0; i < TOTALSTATES; i++) {
+ uint8_t *levels[TCQ_MAX_STATES];
+ uint8_t *prev_levels[TCQ_MAX_STATES];
+ for (int i = 0; i < n_states; i++) {
prev_levels[i] = tcq_levels_prev(tcq_lev, i);
levels[i] = tcq_levels_cur(tcq_lev, i);
}
@@ -1401,8 +1682,8 @@
int col = blk_pos - (row << bwl);
int limits = get_lf_limits(row, col, tx_class, plane);
- tcq_node_t *decision = trellis[scan_pos];
- tcq_node_t *prd = trellis[scan_pos + 1];
+ tcq_node_t *decision = &trellis[scan_pos << n_states_log2];
+ tcq_node_t *prd = &decision[n_states];
prequant_t pqData;
int tempdqv = get_dqv(dequant, scan[scan_pos], iqmatrix);
@@ -1410,77 +1691,78 @@
scan_pos);
// init state
- init_tcq_decision(decision);
+ init_tcq_decision(decision, n_states);
const int coeff_sign = tcoeff[blk_pos] < 0;
// calculate contexts
- int diag_ctx =
- (limits && plane == 0)
- ? get_nz_map_ctx_from_stats_lf(0, blk_pos, bwl, tx_class)
- : plane == 0 ? get_nz_map_ctx_from_stats(0, blk_pos, bwl, tx_class, 0)
- : limits
- ? get_nz_map_ctx_from_stats_lf_chroma(0, tx_class, plane)
- : get_nz_map_ctx_from_stats_chroma(0, blk_pos, tx_class, plane);
tcq_coeff_ctx_t coeff_ctx;
- if (limits) {
- for (int i = 0; i < TOTALSTATES; i++) {
- int base_ctx = plane
- ? get_lower_levels_lf_ctx_chroma(
- prev_levels[i], blk_pos, bwl, tx_class, plane)
- : get_lower_levels_lf_ctx(prev_levels[i], blk_pos,
- bwl, tx_class);
- int br_ctx =
- plane ? get_br_lf_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class)
- : get_br_lf_ctx(prev_levels[i], blk_pos, bwl, tx_class);
- coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
- }
- } else {
- for (int i = 0; i < TOTALSTATES; i++) {
- int base_ctx =
- plane ? get_lower_levels_ctx_chroma(prev_levels[i], blk_pos, bwl,
- tx_class, plane)
- : get_lower_levels_ctx(prev_levels[i], blk_pos, bwl, tx_class
-#if CONFIG_CHROMA_TX_COEFF_CODING
- ,
- plane
-#endif
- );
- int br_ctx =
- plane ? get_br_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class)
- : get_br_ctx(prev_levels[i], blk_pos, bwl, tx_class);
- coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
- }
- }
int eob_ctx = get_lower_levels_ctx_eob(bwl, height, scan_pos);
int eob_rate = block_eob_rate[scan_pos];
coeff_ctx.coef_eob = eob_ctx;
- // calculate rate distortion
tcq_rate_t rd;
- if (limits && plane == 0) {
- av1_get_rate_dist_lf_luma(txb_costs, &pqData, &coeff_ctx, blk_pos,
+
+ // Calculate contexts and rate distortion
+ if (limits) {
+ if (plane == 0) {
+ int diag_ctx = get_nz_map_ctx_from_stats_lf(0, blk_pos, bwl, tx_class);
+ for (int i = 0; i < n_states; i++) {
+ int base_ctx =
+ get_lower_levels_lf_ctx(prev_levels[i], blk_pos, bwl, tx_class);
+ int br_ctx = get_br_lf_ctx(prev_levels[i], blk_pos, bwl, tx_class);
+ coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
+ }
+ f_get_rate_dist_lf_luma(txb_costs, &pqData, &coeff_ctx, blk_pos,
diag_ctx, eob_rate, txb_ctx->dc_sign_ctx,
- tmp_sign, bwl, tx_class, coeff_sign, &rd);
- } else if (limits) {
- av1_get_rate_dist_lf_chroma(txb_costs, &pqData, &coeff_ctx, blk_pos,
- diag_ctx, eob_rate, txb_ctx->dc_sign_ctx,
- tmp_sign, bwl, tx_class, plane, coeff_sign,
- &rd);
- } else if (plane == 0) {
- av1_get_rate_dist_def_luma(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
- tx_class, diag_ctx, eob_rate, &rd);
+ tmp_sign, bwl, tx_class, coeff_sign, n_states,
+ &rd);
+ } else {
+ int diag_ctx = get_nz_map_ctx_from_stats_lf_chroma(0, tx_class, plane);
+ for (int i = 0; i < n_states; i++) {
+ int base_ctx = get_lower_levels_lf_ctx_chroma(prev_levels[i], blk_pos,
+ bwl, tx_class, plane);
+ int br_ctx =
+ get_br_lf_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class);
+ coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
+ }
+ av1_get_rate_dist_lf_chroma(txb_costs, &pqData, &coeff_ctx, blk_pos,
+ diag_ctx, eob_rate, txb_ctx->dc_sign_ctx,
+ tmp_sign, bwl, tx_class, plane, coeff_sign,
+ n_states, &rd);
+ }
} else {
- av1_get_rate_dist_def_chroma(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
- tx_class, diag_ctx, eob_rate, plane,
- tmp_sign[blk_pos], coeff_sign, &rd);
+ if (plane == 0) {
+ int diag_ctx = get_nz_map_ctx_from_stats(0, blk_pos, bwl, tx_class, 0);
+ for (int i = 0; i < n_states; i++) {
+ int base_ctx = get_lower_levels_ctx(prev_levels[i], blk_pos, bwl,
+ tx_class, plane);
+ int br_ctx = get_br_ctx(prev_levels[i], blk_pos, bwl, tx_class);
+ coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
+ }
+ f_get_rate_dist_def_luma(txb_costs, &pqData, &coeff_ctx, blk_pos, bwl,
+ tx_class, diag_ctx, eob_rate, n_states, &rd);
+ } else {
+ int diag_ctx =
+ get_nz_map_ctx_from_stats_chroma(0, blk_pos, tx_class, plane);
+ for (int i = 0; i < n_states; i++) {
+ int base_ctx = get_lower_levels_ctx_chroma(prev_levels[i], blk_pos,
+ bwl, tx_class, plane);
+ int br_ctx =
+ get_br_ctx_chroma(prev_levels[i], blk_pos, bwl, tx_class);
+ coeff_ctx.coef[i] = base_ctx - diag_ctx + (br_ctx << 4);
+ }
+ av1_get_rate_dist_def_chroma(
+ txb_costs, &pqData, &coeff_ctx, blk_pos, bwl, tx_class, diag_ctx,
+ eob_rate, plane, tmp_sign[blk_pos], coeff_sign, n_states, &rd);
+ }
}
- int try_eob = sharpness == 0;
- av1_decide_states(prd, &rd, &pqData, limits, try_eob, rdmult, decision);
+ f_decide_states(prd, &rd, &pqData, n_states, limits, try_eob, rdmult,
+ decision);
// copy corresponding context from previous level buffer
- for (int state = 0; state < TOTALSTATES && scan_pos != first_scan_pos;
+ for (int state = 0; state < n_states && scan_pos != first_scan_pos;
state++) {
int prevId = decision[state].prevId;
if (prevId >= 0)
@@ -1489,7 +1771,7 @@
}
// update levels_buf
- for (int state = 0; state < TOTALSTATES && scan_pos != 0; state++) {
+ for (int state = 0; state < n_states && scan_pos != 0; state++) {
set_levels_buf(decision[state].prevId, decision[state].absLevel,
levels[state], scan, first_scan_pos, scan_pos, bwl,
sharpness);
@@ -1539,16 +1821,18 @@
}
}
-int av1_find_best_path_c(const struct tcq_node_t *trellis, const int16_t *scan,
- const int32_t *dequant, const qm_val_t *iqmatrix,
- const tran_low_t *tcoeff, int first_scan_pos,
- int log_scale, tran_low_t *qcoeff, tran_low_t *dqcoeff,
- int *min_rate, int64_t *min_cost) {
+int av1_find_best_path_c(const struct tcq_node_t *trellis, int n_states_log2,
+ const int16_t *scan, const int32_t *dequant,
+ const qm_val_t *iqmatrix, const tran_low_t *tcoeff,
+ int first_scan_pos, int log_scale, tran_low_t *qcoeff,
+ tran_low_t *dqcoeff, int *min_rate,
+ int64_t *min_cost) {
// Select best trellis state.
+ int n_states = 1 << n_states_log2;
int64_t min_path_cost = INT64_MAX;
int trel_min_rate = INT32_MAX;
int prev_id = -2;
- for (int state = 0; state < TOTALSTATES; state++) {
+ for (int state = 0; state < n_states; state++) {
const tcq_node_t *decision = &trellis[state];
if (decision->rdCost < min_path_cost) {
prev_id = state;
@@ -1563,7 +1847,8 @@
int dqv = dequant[0];
int dqv_ac = dequant[1];
for (; prev_id >= 0; scan_pos++) {
- const tcq_node_t *decision = &trellis[scan_pos * TOTALSTATES + prev_id];
+ const tcq_node_t *decision =
+ &trellis[(scan_pos << n_states_log2) + prev_id];
prev_id = decision->prevId;
int abs_level = decision->absLevel;
int blk_pos = scan[scan_pos];
@@ -1579,7 +1864,8 @@
}
} else {
for (; prev_id >= 0; scan_pos++) {
- const tcq_node_t *decision = &trellis[scan_pos * TOTALSTATES + prev_id];
+ const tcq_node_t *decision =
+ &trellis[(scan_pos << n_states_log2) + prev_id];
prev_id = decision->prevId;
int abs_level = decision->absLevel;
int blk_pos = scan[scan_pos];
@@ -1676,8 +1962,9 @@
// getting context from previous level buf, updating levels on current level
// buf. initialization all value by 0, since we update every position.
+ int n_states_log2 = cm->features.tcq_mode == TCQ_8ST ? 3 : 2;
int bufsize = (width + 4) * (height + 4) + TX_PAD_END;
- int mem_tcq_sz = sizeof(uint8_t) * bufsize * 2 * TOTALSTATES;
+ int mem_tcq_sz = sizeof(uint8_t) * bufsize * (2 << n_states_log2);
uint8_t *mem_tcq = (uint8_t *)malloc(mem_tcq_sz);
if (!mem_tcq) {
exit(1);
@@ -1691,10 +1978,10 @@
int si = eob - 1;
// populate trellis
assert(si < MAX_TRELLIS);
- tcq_node_t trellis[MAX_TRELLIS][TOTALSTATES];
+ tcq_node_t trellis[MAX_TRELLIS * TCQ_MAX_STATES];
// Ping-pong buffers for diagonal contexts.
- tcq_ctx_t tcq_ctx[2 * TOTALSTATES];
+ tcq_ctx_t tcq_ctx[2 * TCQ_MAX_STATES];
// Precalc block eob rate.
uint16_t block_eob_rate[MAX_TRELLIS];
@@ -1792,39 +2079,64 @@
}
#endif // CONFIG_TXFMBLK_LOGS || CONFIG_COEFF_LOGS
+ // Collect TCQ related parameters.
+ tcq_param_t param;
+ int log_scale = av1_get_tx_scale(tx_size) + 1;
+ param.n_states_log2 = n_states_log2;
+ param.n_states = 1 << param.n_states_log2;
+ param.plane = plane;
+ param.tx_size = tx_size;
+ param.tx_class = tx_class;
+ param.sharpness = sharpness;
+ param.rdmult = rdmult;
+ param.log_scale = log_scale;
+ param.scan = scan;
+ param.tmp_sign = xd->tmp_sign;
+ param.qcoeff = qcoeff;
+ param.tcoeff = tcoeff;
+ param.quant = quant;
+ param.dequant = dequant;
+ param.iqmatrix = iqmatrix;
+ param.block_eob_rate = block_eob_rate;
+ param.txb_ctx = txb_ctx;
+ param.txb_costs = txb_costs;
+
// Start of TCQ
int first_scan_pos = si;
- int log_scale = av1_get_tx_scale(tx_size) + 1;
- trellis_first_pos(first_scan_pos, plane, tx_size, tx_class, xd->tmp_sign,
- sharpness, &tcq_lev, trellis, qcoeff, rdmult, log_scale,
- scan, tcoeff, dequant, quant, iqmatrix, block_eob_rate,
- txb_ctx, txb_costs);
+ trellis_first_pos(¶m, first_scan_pos, &tcq_lev, trellis);
+
int scan_hi = first_scan_pos - 1;
if (scan_hi >= 0) {
if (plane == 0 && tx_class == TX_CLASS_2D) {
const int scan_lf_start = 9;
- while (scan_hi > scan_lf_start) {
- int blk_pos = scan[scan_hi];
- int row = blk_pos >> bwl;
- int col = blk_pos - (row << bwl);
- int inc = AOMMIN(height - 1 - row, col);
- int scan_lo = AOMMAX(scan_lf_start + 1, scan_hi - inc);
- trellis_loop_diagonal(scan_hi, scan_lo, 0, tx_size, TX_CLASS_2D, 0,
- sharpness, &tcq_lev, tcq_ctx, trellis, qcoeff,
- rdmult, log_scale, scan, tcoeff, dequant, quant,
- iqmatrix, block_eob_rate, txb_ctx, txb_costs);
- scan_hi = scan_lo - 1;
+ if (param.n_states == 4) {
+ while (scan_hi > scan_lf_start) {
+ int blk_pos = scan[scan_hi];
+ int row = blk_pos >> bwl;
+ int col = blk_pos - (row << bwl);
+ int inc = AOMMIN(height - 1 - row, col);
+ int scan_lo = AOMMAX(scan_lf_start + 1, scan_hi - inc);
+ trellis_loop_diagonal_st4(¶m, scan_hi, scan_lo, &tcq_lev, tcq_ctx,
+ trellis);
+ scan_hi = scan_lo - 1;
+ }
+ trellis_loop_lf_st4(¶m, scan_hi, 0, &tcq_lev, trellis);
+ } else { // n_states == 8
+ while (scan_hi > scan_lf_start) {
+ int blk_pos = scan[scan_hi];
+ int row = blk_pos >> bwl;
+ int col = blk_pos - (row << bwl);
+ int inc = AOMMIN(height - 1 - row, col);
+ int scan_lo = AOMMAX(scan_lf_start + 1, scan_hi - inc);
+ trellis_loop_diagonal_st8(¶m, scan_hi, scan_lo, &tcq_lev, tcq_ctx,
+ trellis);
+ scan_hi = scan_lo - 1;
+ }
+ trellis_loop_lf_st8(¶m, scan_hi, 0, &tcq_lev, trellis);
}
- trellis_loop_lf(scan_hi, 0, plane, tx_size, tx_class, xd->tmp_sign,
- sharpness, &tcq_lev, trellis, qcoeff, rdmult, log_scale,
- scan, tcoeff, dequant, quant, iqmatrix, block_eob_rate,
- txb_ctx, txb_costs);
} else {
- trellis_loop(first_scan_pos, scan_hi, 0, plane, tx_size, tx_class,
- xd->tmp_sign, sharpness, &tcq_lev, trellis, qcoeff, rdmult,
- log_scale, scan, tcoeff, dequant, quant, iqmatrix,
- block_eob_rate, txb_ctx, txb_costs);
+ trellis_loop(¶m, first_scan_pos, scan_hi, 0, &tcq_lev, trellis);
}
}
@@ -1833,8 +2145,8 @@
// find best path
int min_rate = INT32_MAX;
int64_t min_path_cost = INT64_MAX;
- eob = av1_find_best_path(&trellis[0][0], scan, dequant, iqmatrix, tcoeff,
- first_scan_pos, log_scale, qcoeff, dqcoeff,
+ eob = av1_find_best_path(trellis, n_states_log2, scan, dequant, iqmatrix,
+ tcoeff, first_scan_pos, log_scale, qcoeff, dqcoeff,
&min_rate, &min_path_cost);
#if CONFIG_CONTEXT_DERIVATION
diff --git a/av1/encoder/trellis_quant.h b/av1/encoder/trellis_quant.h
index e315689..30e5673 100644
--- a/av1/encoder/trellis_quant.h
+++ b/av1/encoder/trellis_quant.h
@@ -52,17 +52,38 @@
} prequant_t;
typedef struct tcq_rate_t {
- int32_t rate[2 * TOTALSTATES];
- int32_t rate_zero[TOTALSTATES];
+ int32_t rate[2 * TCQ_MAX_STATES];
+ int32_t rate_zero[TCQ_MAX_STATES];
int32_t rate_eob[2];
} tcq_rate_t;
typedef struct tcq_coeff_ctx_t {
- uint8_t coef[TOTALSTATES];
+ uint8_t coef[TCQ_MAX_STATES];
uint8_t coef_eob;
uint8_t pad[3];
} tcq_coeff_ctx_t;
+typedef struct tcq_param_t {
+ int n_states;
+ int n_states_log2;
+ int plane;
+ TX_SIZE tx_size;
+ TX_CLASS tx_class;
+ int sharpness;
+ int64_t rdmult;
+ int log_scale;
+ const int16_t *scan;
+ const int32_t *tmp_sign;
+ const tran_low_t *qcoeff;
+ const tran_low_t *tcoeff;
+ const int32_t *quant;
+ const int32_t *dequant;
+ const qm_val_t *iqmatrix;
+ const uint16_t *block_eob_rate;
+ const TXB_CTX *txb_ctx;
+ const LV_MAP_COEFF_COST *txb_costs;
+} tcq_param_t;
+
static AOM_FORCE_INLINE int get_low_range(int abs_qc, int lf) {
int base_levels = lf ? 6 : 4;
int parity = abs_qc & 1;
diff --git a/av1/encoder/x86/trellis_quant_avx2.c b/av1/encoder/x86/trellis_quant_avx2.c
index d1059ea..6563257 100644
--- a/av1/encoder/x86/trellis_quant_avx2.c
+++ b/av1/encoder/x86/trellis_quant_avx2.c
@@ -21,19 +21,36 @@
#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/x86/synonyms_avx2.h"
+// av1_decide_states_*() constants.
+static const int32_t kShuffle[8] = { 0, 2, 1, 3, 5, 7, 4, 6 };
+static const int32_t kPrevId[TCQ_MAX_STATES / 4][8] = {
+ { 0, 0 << 24, 0, 1 << 24, 0, 2 << 24, 0, 3 << 24 },
+ { 0, 4 << 24, 0, 5 << 24, 0, 6 << 24, 0, 7 << 24 },
+};
+
+// av1_calc_lf_ctx_*() constants.
+// Neighbor mask for calculating context sum (base/mid).
+#define M MAX_VAL_BR_CTX
+static const int8_t kNbrMask[4][32] = {
+ { 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // diag 0
+ M, M, 0, M, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 0, 5, 5, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, // diag 1
+ 0, M, M, 0, 0, M, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 0, 0, 5, 5, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, // diag 2
+ 0, 0, M, M, 0, 0, 0, M, 0, 0, 0, 0, 0, 0, 0, 0 },
+ { 0, 0, 0, 5, 5, 0, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0, // diag 3
+ 0, 0, 0, M, M, 0, 0, 0, 0, M, 0, 0, 0, 0, 0, 0 },
+};
+static const int8_t kMaxCtx[16] = { 8, 6, 6, 4, 4, 4, 4, 4,
+ 4, 4, 4, 4, 4, 4, 4, 4 };
+static const int8_t kScanDiag[MAX_LF_SCAN] = { 0, 1, 1, 2, 2, 2, 3, 3, 3, 3 };
+
void av1_decide_states_avx2(const struct tcq_node_t *prev,
const struct tcq_rate_t *rd,
- const struct prequant_t *pq, int limits,
- int try_eob, int64_t rdmult,
+ const struct prequant_t *pq, int n_states,
+ int limits, int try_eob, int64_t rdmult,
struct tcq_node_t *decision) {
(void)limits;
- static const int32_t kShuffle[8] = { 0, 2, 1, 3, 5, 7, 4, 6 };
- static const int32_t kPrevId[TOTALSTATES / 4][8] = {
- { 0, 0 << 24, 0, 1 << 24, 0, 2 << 24, 0, 3 << 24 },
-#if MORESTATES
- { 0, 4 << 24, 0, 5 << 24, 0, 6 << 24, 0, 7 << 24 },
-#endif
- };
assert((rdmult >> 32) == 0);
assert(sizeof(tcq_node_t) == 16);
@@ -50,7 +67,10 @@
__m256i abslev0033 = _mm256_unpacklo_epi32(c_zero, abslev00223311);
__m256i abslev2211 = _mm256_unpackhi_epi32(c_zero, abslev00223311);
- for (int i = 0; i < TOTALSTATES / 4; i++) {
+ __m256i *out_a = (__m256i *)&decision[0];
+ __m256i *out_b = (__m256i *)&decision[n_states >> 1];
+
+ for (int i = 0; i < n_states >> 2; i++) {
// Load distortion.
__m256i dist = _mm256_lddqu_si256((__m256i *)&pq->deltaDist[0]);
dist = _mm256_slli_epi64(dist, RDDIV_BITS);
@@ -109,8 +129,8 @@
__m256i rdcost3164 = _mm256_shuffle_epi32(rdcost1346, 0x4E);
__m256i rate3164 = _mm256_shuffle_epi32(rate1346, 0x4E);
__m256i use_odd = _mm256_cmpgt_epi64(rdcost0257, rdcost3164);
- __m256i prev_id = _mm256_lddqu_si256((__m256i *)kPrevId[i]);
__m256i use_odd_1 = _mm256_slli_epi64(_mm256_srli_epi64(use_odd, 63), 56);
+ __m256i prev_id = _mm256_lddqu_si256((__m256i *)kPrevId[i]);
prev_id = _mm256_xor_si256(prev_id, use_odd_1);
__m256i rdcost_best = _mm256_blendv_epi8(rdcost0257, rdcost3164, use_odd);
__m256i rate_best = _mm256_blendv_epi8(rate0257, rate3164, use_odd);
@@ -141,16 +161,134 @@
info_best = _mm256_or_si256(info_best, prev_id);
__m256i info01 = _mm256_unpacklo_epi64(rdcost_best, info_best);
__m256i info23 = _mm256_unpackhi_epi64(rdcost_best, info_best);
-#if MORESTATES
- _mm256_storeu_si256((__m256i *)&decision[6 * i], info01);
- _mm256_storeu_si256((__m256i *)&decision[4 - (2 * i)], info23);
-#else
- _mm256_storeu_si256((__m256i *)&decision[0], info01);
- _mm256_storeu_si256((__m256i *)&decision[2], info23);
-#endif
+ _mm256_storeu_si256(out_a, info01);
+ _mm256_storeu_si256(out_b, info23);
+ out_a = (__m256i *)&decision[6];
+ out_b = (__m256i *)&decision[2];
}
}
+void av1_decide_states_st4_avx2(const struct tcq_node_t *prev,
+ const struct tcq_rate_t *rd,
+ const struct prequant_t *pq, int n_states,
+ int limits, int try_eob, int64_t rdmult,
+ struct tcq_node_t *decision) {
+ (void)limits;
+ (void)n_states;
+ assert(n_states == 4);
+ assert((rdmult >> 32) == 0);
+ assert(sizeof(tcq_node_t) == 16);
+
+ int i = 0;
+
+ __m256i c_rdmult = _mm256_set1_epi64x(rdmult);
+ __m256i c_round = _mm256_set1_epi64x(1 << (AV1_PROB_COST_SHIFT - 1));
+ __m256i c_zero = _mm256_setzero_si256();
+
+ // Gather absolute coeff level for 4 possible quant options.
+ __m128i abslev0123 = _mm_lddqu_si128((__m128i *)pq->absLevel);
+ __m256i abslev0231 =
+ _mm256_castsi128_si256(_mm_shuffle_epi32(abslev0123, 0x78));
+ __m256i abslev02023131 = _mm256_permute4x64_epi64(abslev0231, 0x50);
+ __m256i abslev00223311 = _mm256_shuffle_epi32(abslev02023131, 0x50);
+ __m256i abslev0033 = _mm256_unpacklo_epi32(c_zero, abslev00223311);
+ __m256i abslev2211 = _mm256_unpackhi_epi32(c_zero, abslev00223311);
+
+ // Load distortion.
+ __m256i dist = _mm256_lddqu_si256((__m256i *)&pq->deltaDist[0]);
+ dist = _mm256_slli_epi64(dist, RDDIV_BITS);
+ __m256i dist0033 = _mm256_permute4x64_epi64(dist, 0xF0);
+ __m256i dist2211 = _mm256_permute4x64_epi64(dist, 0x5A);
+
+ // Calc rate-distortion costs for each pair of even/odd quant.
+ // Separate candidates into even and odd quant decisions
+ // Even indexes: { 0, 2, 5, 7 }. Odd: { 1, 3, 4, 6 }.
+ __m256i rates = _mm256_lddqu_si256((__m256i *)&rd->rate[8 * i]);
+ __m256i permute_mask = _mm256_lddqu_si256((__m256i *)kShuffle);
+ __m256i rate02135746 = _mm256_permutevar8x32_epi32(rates, permute_mask);
+ __m256i rate0257 = _mm256_unpacklo_epi32(rate02135746, c_zero);
+ __m256i rate1346 = _mm256_unpackhi_epi32(rate02135746, c_zero);
+ __m256i rdcost0257 = _mm256_mul_epu32(c_rdmult, rate0257);
+ __m256i rdcost1346 = _mm256_mul_epu32(c_rdmult, rate1346);
+ rdcost0257 = _mm256_add_epi64(rdcost0257, c_round);
+ rdcost1346 = _mm256_add_epi64(rdcost1346, c_round);
+ rdcost0257 = _mm256_srli_epi64(rdcost0257, AV1_PROB_COST_SHIFT);
+ rdcost1346 = _mm256_srli_epi64(rdcost1346, AV1_PROB_COST_SHIFT);
+ rdcost0257 = _mm256_add_epi64(rdcost0257, dist0033);
+ rdcost1346 = _mm256_add_epi64(rdcost1346, dist2211);
+
+ // Calc rd-cost for zero quant.
+ __m256i ratezero =
+ _mm256_castsi128_si256(_mm_lddqu_si128((__m128i *)&rd->rate_zero[4 * i]));
+ ratezero = _mm256_permute4x64_epi64(ratezero, 0x50);
+ ratezero = _mm256_unpacklo_epi32(ratezero, c_zero);
+ __m256i rdcostzero = _mm256_mul_epu32(c_rdmult, ratezero);
+ rdcostzero = _mm256_add_epi64(rdcostzero, c_round);
+ rdcostzero = _mm256_srli_epi64(rdcostzero, AV1_PROB_COST_SHIFT);
+
+ // Add previous state rdCost to rdcostzero
+ __m256i state01 = _mm256_lddqu_si256((__m256i *)&prev[4 * i]);
+ __m256i state23 = _mm256_lddqu_si256((__m256i *)&prev[4 * i + 2]);
+ __m256i state02 = _mm256_permute2x128_si256(state01, state23, 0x20);
+ __m256i state13 = _mm256_permute2x128_si256(state01, state23, 0x31);
+ __m256i prevrd0123 = _mm256_unpacklo_epi64(state02, state13);
+ __m256i prevrate0123 = _mm256_unpackhi_epi64(state02, state13);
+ prevrate0123 = _mm256_slli_epi64(prevrate0123, 32);
+ prevrate0123 = _mm256_srli_epi64(prevrate0123, 32);
+
+ // Compare rd costs (Zero vs Even).
+ __m256i use_zero = _mm256_cmpgt_epi64(rdcost0257, rdcostzero);
+ rdcost0257 = _mm256_blendv_epi8(rdcost0257, rdcostzero, use_zero);
+ rate0257 = _mm256_blendv_epi8(rate0257, ratezero, use_zero);
+ __m256i abslev_even = _mm256_andnot_si256(use_zero, abslev0033);
+
+ // Add previous state rdCost to current rdcost
+ rdcost0257 = _mm256_add_epi64(rdcost0257, prevrd0123);
+ rdcost1346 = _mm256_add_epi64(rdcost1346, prevrd0123);
+ rate0257 = _mm256_add_epi64(rate0257, prevrate0123);
+ rate1346 = _mm256_add_epi64(rate1346, prevrate0123);
+
+ // Compare rd costs (Even vs Odd).
+ __m256i rdcost3164 = _mm256_shuffle_epi32(rdcost1346, 0x4E);
+ __m256i rate3164 = _mm256_shuffle_epi32(rate1346, 0x4E);
+ __m256i use_odd = _mm256_cmpgt_epi64(rdcost0257, rdcost3164);
+ __m256i use_odd_1 = _mm256_slli_epi64(_mm256_srli_epi64(use_odd, 63), 56);
+ __m256i prev_id = _mm256_lddqu_si256((__m256i *)kPrevId[i]);
+ prev_id = _mm256_xor_si256(prev_id, use_odd_1);
+ __m256i rdcost_best = _mm256_blendv_epi8(rdcost0257, rdcost3164, use_odd);
+ __m256i rate_best = _mm256_blendv_epi8(rate0257, rate3164, use_odd);
+ __m256i abslev_best = _mm256_blendv_epi8(abslev_even, abslev2211, use_odd);
+
+ // Compare rd costs (best vs new eob).
+ __m256i rate_eob = _mm256_castsi128_si256(_mm_loadu_si64(rd->rate_eob));
+ rate_eob = _mm256_unpacklo_epi32(rate_eob, c_zero);
+ __m256i rdcost_eob = _mm256_mul_epu32(c_rdmult, rate_eob);
+ rdcost_eob = _mm256_add_epi64(rdcost_eob, c_round);
+ rdcost_eob = _mm256_srli_epi64(rdcost_eob, AV1_PROB_COST_SHIFT);
+ __m256i dist_eob = _mm256_unpacklo_epi64(dist0033, dist2211);
+ rdcost_eob = _mm256_add_epi64(rdcost_eob, dist_eob);
+ __m128i mask_eob0 = _mm_set1_epi64x((int64_t)-try_eob);
+ __m256i mask_eob = _mm256_inserti128_si256(c_zero, mask_eob0, 0);
+ __m256i use_eob = _mm256_cmpgt_epi64(rdcost_best, rdcost_eob);
+ use_eob = _mm256_and_si256(use_eob, mask_eob);
+ __m256i use_eob_1 = _mm256_slli_epi64(use_eob, 56);
+ prev_id = _mm256_or_si256(prev_id, use_eob_1);
+ rdcost_best = _mm256_blendv_epi8(rdcost_best, rdcost_eob, use_eob);
+ rate_best = _mm256_blendv_epi8(rate_best, rate_eob, use_eob);
+ __m256i abslev_eob = _mm256_unpacklo_epi64(abslev0033, abslev2211);
+ abslev_best = _mm256_blendv_epi8(abslev_best, abslev_eob, use_eob);
+
+ // Pack and store state info.
+ __m256i info_best = _mm256_or_si256(rate_best, abslev_best);
+ info_best = _mm256_or_si256(info_best, prev_id);
+ __m256i info01 = _mm256_unpacklo_epi64(rdcost_best, info_best);
+ __m256i info23 = _mm256_unpackhi_epi64(rdcost_best, info_best);
+ __m256i *out_a = (__m256i *)&decision[0];
+ __m256i *out_b = (__m256i *)&decision[2];
+ _mm256_storeu_si256(out_a, info01);
+ _mm256_storeu_si256(out_b, info23);
+}
+
void av1_pre_quant_avx2(tran_low_t tqc, struct prequant_t *pqData,
const int32_t *quant_ptr, int dqv, int log_scale,
int scan_pos) {
@@ -197,10 +335,10 @@
_mm256_storeu_si256((__m256i *)pqData->deltaDist, dist);
}
-void av1_update_states_avx2(tcq_node_t *decision, int scan_idx,
+void av1_update_states_avx2(tcq_node_t *decision, int scan_idx, int n_states,
const struct tcq_ctx_t *cur_ctx,
struct tcq_ctx_t *nxt_ctx) {
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int prevId = decision[i].prevId;
int absLevel = decision[i].absLevel;
if (prevId >= 0) {
@@ -392,7 +530,7 @@
const struct prequant_t *pq,
const tcq_coeff_ctx_t *coeff_ctx,
int blk_pos, int bwl, TX_CLASS tx_class,
- int diag_ctx, int eob_rate,
+ int diag_ctx, int eob_rate, int n_states,
struct tcq_rate_t *rd) {
(void)bwl;
const int32_t(*cost_zero)[SIG_COEF_CONTEXTS] = txb_costs->base_cost_zero;
@@ -419,11 +557,9 @@
__m256i ratez_0123 = _mm256_unpacklo_epi64(ratez_dq0, ratez_dq1);
_mm_storeu_si128((__m128i *)&rd->rate_zero[0],
_mm256_castsi256_si128(ratez_0123));
-#if MORESTATES
__m256i ratez_4567 = _mm256_unpackhi_epi64(ratez_dq0, ratez_dq1);
_mm_storeu_si128((__m128i *)&rd->rate_zero[4],
_mm256_castsi256_si128(ratez_4567));
-#endif
// Calc coeff_base rate.
int idx = AOMMIN(pq->qIdx - 1, 4);
@@ -432,7 +568,7 @@
__m256i base_ctx = _mm256_slli_epi16(ctx16, 12);
base_ctx = _mm256_srli_epi16(base_ctx, 12);
base_ctx = _mm256_add_epi16(base_ctx, diag);
- for (int i = 0; i < TOTALSTATES / 4; i++) {
+ for (int i = 0; i < (n_states >> 2); i++) {
int ctx0 = _mm256_extract_epi16(base_ctx, 0);
int ctx1 = _mm256_extract_epi16(base_ctx, 1);
int ctx2 = _mm256_extract_epi16(base_ctx, 2);
@@ -460,7 +596,7 @@
// Calc coeff mid and high range cost.
if (idx > 0) {
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int a0 = i & 2 ? 1 : 0;
int a1 = a0 + 2;
int mid_cost0 = get_mid_cost_def(absLevel[a0], coeff_ctx->coef[i],
@@ -479,23 +615,96 @@
}
}
-void av1_calc_lf_ctx_avx2(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
- struct tcq_coeff_ctx_t *coeff_ctx) {
-#define M MAX_VAL_BR_CTX
- // Neighbor mask for calculating context sum (base/mid).
- static const int8_t kNbrMask[4][32] = {
- { 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // diag 0
- M, M, 0, M, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
- { 0, 5, 5, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, // diag 1
- 0, M, M, 0, 0, M, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
- { 0, 0, 5, 5, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, // diag 2
- 0, 0, M, M, 0, 0, 0, M, 0, 0, 0, 0, 0, 0, 0, 0 },
- { 0, 0, 0, 5, 5, 0, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0, // diag 3
- 0, 0, 0, M, M, 0, 0, 0, 0, M, 0, 0, 0, 0, 0, 0 },
- };
- static const int8_t kMaxCtx[16] = { 8, 6, 6, 4, 4, 4, 4, 4,
- 4, 4, 4, 4, 4, 4, 4, 4 };
- static const int8_t kScanDiag[MAX_LF_SCAN] = { 0, 1, 1, 2, 2, 2, 3, 3, 3, 3 };
+void av1_get_rate_dist_def_luma_st4_avx2(
+ const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
+ const tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl, TX_CLASS tx_class,
+ int diag_ctx, int eob_rate, int n_states, struct tcq_rate_t *rd) {
+ (void)bwl;
+ assert(n_states == 4);
+ n_states = 4;
+
+ const int32_t(*cost_zero)[SIG_COEF_CONTEXTS] = txb_costs->base_cost_zero;
+ const uint16_t(*cost_low_tbl)[SIG_COEF_CONTEXTS][DQ_CTXS][2] =
+ txb_costs->base_cost_low_tbl;
+ const uint16_t(*cost_eob_tbl)[SIG_COEF_CONTEXTS_EOB][2] =
+ txb_costs->base_eob_cost_tbl;
+ const tran_low_t *absLevel = pq->absLevel;
+
+ // Calc zero coeff costs.
+ __m256i zero = _mm256_setzero_si256();
+ __m256i cost_zero_dq0 =
+ _mm256_lddqu_si256((__m256i *)&cost_zero[0][diag_ctx]);
+ __m256i cost_zero_dq1 =
+ _mm256_lddqu_si256((__m256i *)&cost_zero[1][diag_ctx]);
+
+ __m256i coef_ctx = _mm256_castsi128_si256(_mm_loadu_si64(&coeff_ctx->coef));
+ __m256i ctx16 = _mm256_unpacklo_epi8(coef_ctx, zero);
+ __m256i ctx = _mm256_shuffle_epi32(ctx16, 0xD8);
+ __m256i ctx_dq0 = _mm256_unpacklo_epi16(ctx, zero);
+ __m256i ctx_dq1 = _mm256_unpackhi_epi16(ctx, zero);
+ __m256i ratez_dq0 = _mm256_permutevar8x32_epi32(cost_zero_dq0, ctx_dq0);
+ __m256i ratez_dq1 = _mm256_permutevar8x32_epi32(cost_zero_dq1, ctx_dq1);
+ __m256i ratez_0123 = _mm256_unpacklo_epi64(ratez_dq0, ratez_dq1);
+ _mm_storeu_si128((__m128i *)&rd->rate_zero[0],
+ _mm256_castsi256_si128(ratez_0123));
+
+ // Calc coeff_base rate.
+ int idx = AOMMIN(pq->qIdx - 1, 4);
+ __m128i c_zero = _mm_setzero_si128();
+ __m256i diag = _mm256_set1_epi16(diag_ctx);
+ __m256i base_ctx = _mm256_slli_epi16(ctx16, 12);
+ base_ctx = _mm256_srli_epi16(base_ctx, 12);
+ base_ctx = _mm256_add_epi16(base_ctx, diag);
+ for (int i = 0; i < (n_states >> 2); i++) {
+ int ctx0 = _mm256_extract_epi16(base_ctx, 0);
+ int ctx1 = _mm256_extract_epi16(base_ctx, 1);
+ int ctx2 = _mm256_extract_epi16(base_ctx, 2);
+ int ctx3 = _mm256_extract_epi16(base_ctx, 3);
+ base_ctx = _mm256_bsrli_epi128(base_ctx, 8);
+ __m128i rate_01 = _mm_loadu_si64(&cost_low_tbl[idx][ctx0][0]);
+ __m128i rate_23 = _mm_loadu_si64(&cost_low_tbl[idx][ctx1][0]);
+ __m128i rate_45 = _mm_loadu_si64(&cost_low_tbl[idx][ctx2][1]);
+ __m128i rate_67 = _mm_loadu_si64(&cost_low_tbl[idx][ctx3][1]);
+ __m128i rate_0123 = _mm_unpacklo_epi32(rate_01, rate_23);
+ __m128i rate_4567 = _mm_unpacklo_epi32(rate_45, rate_67);
+ rate_0123 = _mm_unpacklo_epi16(rate_0123, c_zero);
+ rate_4567 = _mm_unpacklo_epi16(rate_4567, c_zero);
+ _mm_storeu_si128((__m128i *)&rd->rate[8 * i], rate_0123);
+ _mm_storeu_si128((__m128i *)&rd->rate[8 * i + 4], rate_4567);
+ }
+
+ // Calc coeff/eob cost.
+ int eob_ctx = coeff_ctx->coef_eob;
+ __m128i rate_eob_coef = _mm_loadu_si64(&cost_eob_tbl[idx][eob_ctx][0]);
+ rate_eob_coef = _mm_unpacklo_epi16(rate_eob_coef, c_zero);
+ __m128i rate_eob_position = _mm_set1_epi32(eob_rate);
+ __m128i rate_eob = _mm_add_epi32(rate_eob_coef, rate_eob_position);
+ _mm_storeu_si64(&rd->rate_eob[0], rate_eob);
+
+ // Calc coeff mid and high range cost.
+ if (idx > 0) {
+ for (int i = 0; i < n_states; i++) {
+ int a0 = i & 2 ? 1 : 0;
+ int a1 = a0 + 2;
+ int mid_cost0 = get_mid_cost_def(absLevel[a0], coeff_ctx->coef[i],
+ txb_costs, 0, 0, 0);
+ int mid_cost1 = get_mid_cost_def(absLevel[a1], coeff_ctx->coef[i],
+ txb_costs, 0, 0, 0);
+ rd->rate[2 * i] += mid_cost0;
+ rd->rate[2 * i + 1] += mid_cost1;
+ }
+ int eob_mid_cost0 = get_mid_cost_eob(blk_pos, 0, 0, absLevel[0], 0, 0,
+ txb_costs, tx_class, 0, 0);
+ int eob_mid_cost1 = get_mid_cost_eob(blk_pos, 0, 0, absLevel[2], 0, 0,
+ txb_costs, tx_class, 0, 0);
+ rd->rate_eob[0] += eob_mid_cost0;
+ rd->rate_eob[1] += eob_mid_cost1;
+ }
+}
+
+void av1_calc_lf_ctx_st4_avx2(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
+ struct tcq_coeff_ctx_t *coeff_ctx) {
+ int n_states = 4;
int diag = kScanDiag[scan_pos];
__m256i zero = _mm256_setzero_si256();
@@ -503,7 +712,7 @@
__m256i base_mask = _mm256_permute2x128_si256(nbr_mask, nbr_mask, 0);
__m256i mid_mask = _mm256_permute2x128_si256(nbr_mask, nbr_mask, 0x11);
- for (int st = 0; st < TOTALSTATES; st += 4) {
+ for (int st = 0; st < n_states; st += 4) {
// Load previously decoded LF context values.
__m256i last01 = _mm256_lddqu_si256((__m256i *)&lf_ctx[st]);
__m256i last23 = _mm256_lddqu_si256((__m256i *)&lf_ctx[st + 2]);
@@ -546,17 +755,86 @@
ctx16 = _mm256_add_epi16(ctx16, mid_ctx_offset);
__m128i ctx8 = _mm256_castsi256_si128(ctx16);
ctx8 = _mm_packus_epi16(ctx8, ctx8);
- _mm_storeu_si64(&coeff_ctx->coef[st], ctx8);
+#if 1
+ // Older compilers don't implement _mm_storeu_si32()
+ _mm_store_ss((float *)&coeff_ctx->coef[st], _mm_castsi128_ps(ctx8));
+#else
+ _mm_storeu_si32(&coeff_ctx->coef[st], ctx8);
+#endif
}
}
-void av1_update_lf_ctx_avx2(const struct tcq_node_t *decision,
+void av1_calc_lf_ctx_st8_avx2(const struct tcq_lf_ctx_t *lf_ctx, int scan_pos,
+ struct tcq_coeff_ctx_t *coeff_ctx) {
+ int n_states = 8;
+
+ int diag = kScanDiag[scan_pos];
+ __m256i zero = _mm256_setzero_si256();
+ __m256i nbr_mask = _mm256_lddqu_si256((__m256i *)kNbrMask[diag]);
+ __m256i base_mask = _mm256_permute2x128_si256(nbr_mask, nbr_mask, 0);
+ __m256i mid_mask = _mm256_permute2x128_si256(nbr_mask, nbr_mask, 0x11);
+
+ for (int st = 0; st < n_states; st += 4) {
+ // Load previously decoded LF context values.
+ __m256i last01 = _mm256_lddqu_si256((__m256i *)&lf_ctx[st]);
+ __m256i last23 = _mm256_lddqu_si256((__m256i *)&lf_ctx[st + 2]);
+
+ // Calc base ctx neighbor sum.
+ __m256i base01 = _mm256_min_epu8(last01, base_mask);
+ __m256i base23 = _mm256_min_epu8(last23, base_mask);
+ __m256i base01_sum = _mm256_sad_epu8(base01, zero);
+ __m256i base23_sum = _mm256_sad_epu8(base23, zero);
+ __m256i base_sum =
+ _mm256_hadd_epi32(base01_sum, base23_sum); // B0 B0 B2 B2 B1 B1 B3 B3
+
+ // Calc mid ctx neighbor sum.
+ __m256i mid01 = _mm256_min_epu8(last01, mid_mask);
+ __m256i mid23 = _mm256_min_epu8(last23, mid_mask);
+ __m256i mid01_sum = _mm256_sad_epu8(mid01, zero);
+ __m256i mid23_sum = _mm256_sad_epu8(mid23, zero);
+ __m256i mid_sum =
+ _mm256_hadd_epi32(mid01_sum, mid23_sum); // M0 M0 M2 M2 M1 M1 M3 M3
+
+ // Context calc; combine and reduce to 8 bits.
+ __m256i base_mid =
+ _mm256_hadd_epi32(base_sum, mid_sum); // B0B2 M0M2 B1B3 M1M3
+ base_mid = _mm256_hadd_epi16(
+ base_mid, zero); // reduce to 16 bits B0B2 M0M2 - - B1B3 M1M3 - -
+ base_mid = _mm256_avg_epu16(base_mid, zero); // x = (x + 1) >> 1
+ base_mid = _mm256_shufflelo_epi16(
+ base_mid, 0xD8); // shuffle B0M0 B2M2 - - B1M1 B3M3 - -
+ base_mid = _mm256_permute4x64_epi64(
+ base_mid, 0xD8); // pack into lower half: B0M0 B2M2 B1M1 B3M3
+ base_mid = _mm256_shuffle_epi32(base_mid, 0xD8); // B0M0 B1M1 B2M2 B3M3
+ __m256i six = _mm256_set1_epi16(6);
+ __m256i mid = _mm256_min_epi16(base_mid, six);
+ __m256i mid_sh4 = _mm256_slli_epi16(mid, 4);
+ __m256i base_max = _mm256_set1_epi16(kMaxCtx[scan_pos]);
+ __m256i base = _mm256_min_epi16(base_mid, base_max);
+ base_mid = _mm256_blend_epi16(base, mid_sh4, 0xAA);
+ __m256i ctx16 = _mm256_hadd_epi16(base_mid, base_mid);
+ __m256i mid_ctx_offset = _mm256_set1_epi16((scan_pos == 0) ? 0 : (7 << 4));
+ ctx16 = _mm256_add_epi16(ctx16, mid_ctx_offset);
+ __m128i ctx8 = _mm256_castsi256_si128(ctx16);
+ ctx8 = _mm_packus_epi16(ctx8, ctx8);
+#if 1
+ // Older compilers don't implement _mm_storeu_si32()
+ _mm_store_ss((float *)&coeff_ctx->coef[st], _mm_castsi128_ps(ctx8));
+#else
+ _mm_storeu_si32(&coeff_ctx->coef[st], ctx8);
+#endif
+ }
+}
+
+void av1_update_lf_ctx_avx2(const struct tcq_node_t *decision, int n_states,
struct tcq_lf_ctx_t *lf_ctx) {
- __m256i upd_last_a;
- __m256i upd_last_b;
- __m256i upd_last_c;
- __m256i upd_last_d;
- for (int st = 0; st < TOTALSTATES; st += 2) {
+ __m256i c_zero = _mm256_setzero_si256();
+ __m256i upd_last_a = c_zero;
+ __m256i upd_last_b = c_zero;
+ __m256i upd_last_c = c_zero;
+ __m256i upd_last_d = c_zero;
+
+ for (int st = 0; st < n_states; st += 2) {
int absLevel0 = decision[st].absLevel;
int prevId0 = decision[st].prevId;
int absLevel1 = decision[st + 1].absLevel;
@@ -580,26 +858,26 @@
upd_last_b = upd_last_a;
upd_last_a = upd01;
}
-#if MORESTATES
- _mm256_storeu_si256((__m256i *)lf_ctx[0].last, upd_last_d);
- _mm256_storeu_si256((__m256i *)lf_ctx[2].last, upd_last_c);
- _mm256_storeu_si256((__m256i *)lf_ctx[4].last, upd_last_b);
- _mm256_storeu_si256((__m256i *)lf_ctx[6].last, upd_last_a);
-#else
- (void)upd_last_d;
- (void)upd_last_c;
- _mm256_storeu_si256((__m256i *)lf_ctx[0].last, upd_last_b);
- _mm256_storeu_si256((__m256i *)lf_ctx[2].last, upd_last_a);
-#endif
+ if (n_states == 4) {
+ (void)upd_last_d;
+ (void)upd_last_c;
+ _mm256_storeu_si256((__m256i *)lf_ctx[0].last, upd_last_b);
+ _mm256_storeu_si256((__m256i *)lf_ctx[2].last, upd_last_a);
+ } else {
+ _mm256_storeu_si256((__m256i *)lf_ctx[0].last, upd_last_d);
+ _mm256_storeu_si256((__m256i *)lf_ctx[2].last, upd_last_c);
+ _mm256_storeu_si256((__m256i *)lf_ctx[4].last, upd_last_b);
+ _mm256_storeu_si256((__m256i *)lf_ctx[6].last, upd_last_a);
+ }
}
void av1_get_rate_dist_lf_luma_avx2(const struct LV_MAP_COEFF_COST *txb_costs,
const struct prequant_t *pq,
const struct tcq_coeff_ctx_t *coeff_ctx,
int blk_pos, int diag_ctx, int eob_rate,
- int dc_sign_ctx, int32_t *tmp_sign, int bwl,
- TX_CLASS tx_class, int coeff_sign,
- struct tcq_rate_t *rd) {
+ int dc_sign_ctx, const int32_t *tmp_sign,
+ int bwl, TX_CLASS tx_class, int coeff_sign,
+ int n_states, struct tcq_rate_t *rd) {
#define Z -1
static const int8_t kShuf[2][32] = {
{ 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15,
@@ -636,18 +914,14 @@
ratez = _mm256_permute4x64_epi64(ratez, 0x88);
__m256i shuf1 = _mm256_lddqu_si256((__m256i *)kShuf[1]);
ratez = _mm256_shuffle_epi8(ratez, shuf1);
-#if MORESTATES
_mm256_storeu_si256((__m256i *)&rd->rate_zero[0], ratez);
-#else
- _mm_storeu_si128((__m128i *)&rd->rate_zero[0], _mm256_castsi256_si128(ratez));
-#endif
// Calc coeff_base rate.
int idx = AOMMIN(pq->qIdx - 1, 8);
__m128i c_zero = _mm_setzero_si128();
__m256i diag = _mm256_set1_epi8(diag_ctx);
base_ctx = _mm256_add_epi8(base_ctx, diag);
- for (int i = 0; i < TOTALSTATES / 4; i++) {
+ for (int i = 0; i < (n_states >> 2); i++) {
int ctx0 = _mm256_extract_epi8(base_ctx, 0);
int ctx1 = _mm256_extract_epi8(base_ctx, 1);
int ctx2 = _mm256_extract_epi8(base_ctx, 2);
@@ -680,7 +954,7 @@
const bool dc_ver = (row == 0) && tx_class == TX_CLASS_VERT;
const bool is_dc_coeff = dc_2dtx || dc_hor || dc_ver;
if (is_dc_coeff) {
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int a0 = i & 2 ? 1 : 0;
int a1 = a0 + 2;
int mid_cost0 = get_mid_cost_lf_dc(blk_pos, absLevel[a0], coeff_sign,
@@ -702,7 +976,134 @@
rd->rate_eob[0] += eob_mid_cost0;
rd->rate_eob[1] += eob_mid_cost1;
} else if (idx > 4) {
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
+ int a0 = i & 2 ? 1 : 0;
+ int a1 = a0 + 2;
+ int mid_cost0 =
+ get_mid_cost_lf(absLevel[a0], coeff_ctx->coef[i], txb_costs, plane);
+ int mid_cost1 =
+ get_mid_cost_lf(absLevel[a1], coeff_ctx->coef[i], txb_costs, plane);
+ rd->rate[2 * i] += mid_cost0;
+ rd->rate[2 * i + 1] += mid_cost1;
+ }
+ int t_sign = tmp_sign[blk_pos];
+ int eob_mid_cost0 =
+ get_mid_cost_eob(blk_pos, 1, 0, absLevel[0], coeff_sign, dc_sign_ctx,
+ txb_costs, tx_class, t_sign, 0);
+ int eob_mid_cost1 =
+ get_mid_cost_eob(blk_pos, 1, 0, absLevel[2], coeff_sign, dc_sign_ctx,
+ txb_costs, tx_class, t_sign, 0);
+ rd->rate_eob[0] += eob_mid_cost0;
+ rd->rate_eob[1] += eob_mid_cost1;
+ }
+}
+
+void av1_get_rate_dist_lf_luma_st4_avx2(
+ const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
+ const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int diag_ctx,
+ int eob_rate, int dc_sign_ctx, const int32_t *tmp_sign, int bwl,
+ TX_CLASS tx_class, int coeff_sign, int n_states, struct tcq_rate_t *rd) {
+ assert(n_states == 4);
+ n_states = 4;
+#define Z -1
+ static const int8_t kShuf[2][32] = {
+ { 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15,
+ 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 },
+ { 0, 8, Z, Z, 1, 9, Z, Z, 2, 10, Z, Z, 3, 11, Z, Z,
+ 4, 12, Z, Z, 5, 13, Z, Z, 6, 14, Z, Z, 7, 15, Z, Z }
+ };
+ const uint16_t(*cost_zero)[LF_SIG_COEF_CONTEXTS] =
+ txb_costs->base_lf_cost_zero;
+ const uint16_t(*cost_low_tbl)[LF_SIG_COEF_CONTEXTS][DQ_CTXS][2] =
+ txb_costs->base_lf_cost_low_tbl;
+ const uint16_t(*cost_eob_tbl)[SIG_COEF_CONTEXTS_EOB][2] =
+ txb_costs->base_lf_eob_cost_tbl;
+ const tran_low_t *absLevel = pq->absLevel;
+ const int plane = 0;
+
+ // Calc zero coeff costs.
+ __m256i cost_zero_dq0 =
+ _mm256_lddqu_si256((__m256i *)&cost_zero[0][diag_ctx]);
+ __m256i cost_zero_dq1 =
+ _mm256_lddqu_si256((__m256i *)&cost_zero[1][diag_ctx]);
+ __m256i shuf = _mm256_lddqu_si256((__m256i *)kShuf[0]);
+ cost_zero_dq0 = _mm256_shuffle_epi8(cost_zero_dq0, shuf);
+ cost_zero_dq1 = _mm256_shuffle_epi8(cost_zero_dq1, shuf);
+ __m256i cost_dq0 = _mm256_permute4x64_epi64(cost_zero_dq0, 0xD8);
+ __m256i cost_dq1 = _mm256_permute4x64_epi64(cost_zero_dq1, 0xD8);
+ __m256i ctx = _mm256_castsi128_si256(_mm_loadu_si64(&coeff_ctx->coef));
+ __m256i fifteen = _mm256_set1_epi8(15);
+ __m256i base_ctx = _mm256_and_si256(ctx, fifteen);
+ __m256i base_ctx1 = _mm256_permute4x64_epi64(base_ctx, 0);
+ __m256i ratez_dq0 = _mm256_shuffle_epi8(cost_dq0, base_ctx1);
+ __m256i ratez_dq1 = _mm256_shuffle_epi8(cost_dq1, base_ctx1);
+ __m256i ratez = _mm256_blend_epi16(ratez_dq0, ratez_dq1, 0xAA);
+ ratez = _mm256_permute4x64_epi64(ratez, 0x88);
+ __m256i shuf1 = _mm256_lddqu_si256((__m256i *)kShuf[1]);
+ ratez = _mm256_shuffle_epi8(ratez, shuf1);
+ _mm256_storeu_si256((__m256i *)&rd->rate_zero[0], ratez);
+
+ // Calc coeff_base rate.
+ int idx = AOMMIN(pq->qIdx - 1, 8);
+ __m128i c_zero = _mm_setzero_si128();
+ __m256i diag = _mm256_set1_epi8(diag_ctx);
+ base_ctx = _mm256_add_epi8(base_ctx, diag);
+ for (int i = 0; i < (n_states >> 2); i++) {
+ int ctx0 = _mm256_extract_epi8(base_ctx, 0);
+ int ctx1 = _mm256_extract_epi8(base_ctx, 1);
+ int ctx2 = _mm256_extract_epi8(base_ctx, 2);
+ int ctx3 = _mm256_extract_epi8(base_ctx, 3);
+ base_ctx = _mm256_bsrli_epi128(base_ctx, 4);
+ __m128i rate_01 = _mm_loadu_si64(&cost_low_tbl[idx][ctx0][0]);
+ __m128i rate_23 = _mm_loadu_si64(&cost_low_tbl[idx][ctx1][0]);
+ __m128i rate_45 = _mm_loadu_si64(&cost_low_tbl[idx][ctx2][1]);
+ __m128i rate_67 = _mm_loadu_si64(&cost_low_tbl[idx][ctx3][1]);
+ __m128i rate_0123 = _mm_unpacklo_epi32(rate_01, rate_23);
+ __m128i rate_4567 = _mm_unpacklo_epi32(rate_45, rate_67);
+ rate_0123 = _mm_unpacklo_epi16(rate_0123, c_zero);
+ rate_4567 = _mm_unpacklo_epi16(rate_4567, c_zero);
+ _mm_storeu_si128((__m128i *)&rd->rate[8 * i], rate_0123);
+ _mm_storeu_si128((__m128i *)&rd->rate[8 * i + 4], rate_4567);
+ }
+
+ // Calc coeff/eob cost.
+ int eob_ctx = coeff_ctx->coef_eob;
+ __m128i rate_eob_coef = _mm_loadu_si64(&cost_eob_tbl[idx][eob_ctx][0]);
+ rate_eob_coef = _mm_unpacklo_epi16(rate_eob_coef, c_zero);
+ __m128i rate_eob_position = _mm_set1_epi32(eob_rate);
+ __m128i rate_eob = _mm_add_epi32(rate_eob_coef, rate_eob_position);
+ _mm_storeu_si64(&rd->rate_eob[0], rate_eob);
+
+ const int row = blk_pos >> bwl;
+ const int col = blk_pos - (row << bwl);
+ const bool dc_2dtx = (blk_pos == 0);
+ const bool dc_hor = (col == 0) && tx_class == TX_CLASS_HORIZ;
+ const bool dc_ver = (row == 0) && tx_class == TX_CLASS_VERT;
+ const bool is_dc_coeff = dc_2dtx || dc_hor || dc_ver;
+ if (is_dc_coeff) {
+ for (int i = 0; i < n_states; i++) {
+ int a0 = i & 2 ? 1 : 0;
+ int a1 = a0 + 2;
+ int mid_cost0 = get_mid_cost_lf_dc(blk_pos, absLevel[a0], coeff_sign,
+ coeff_ctx->coef[i], dc_sign_ctx,
+ txb_costs, tmp_sign, plane);
+ int mid_cost1 = get_mid_cost_lf_dc(blk_pos, absLevel[a1], coeff_sign,
+ coeff_ctx->coef[i], dc_sign_ctx,
+ txb_costs, tmp_sign, plane);
+ rd->rate[2 * i] += mid_cost0;
+ rd->rate[2 * i + 1] += mid_cost1;
+ }
+ int t_sign = tmp_sign[blk_pos];
+ int eob_mid_cost0 =
+ get_mid_cost_eob(blk_pos, 1, 1, absLevel[0], coeff_sign, dc_sign_ctx,
+ txb_costs, tx_class, t_sign, 0);
+ int eob_mid_cost1 =
+ get_mid_cost_eob(blk_pos, 1, 1, absLevel[2], coeff_sign, dc_sign_ctx,
+ txb_costs, tx_class, t_sign, 0);
+ rd->rate_eob[0] += eob_mid_cost0;
+ rd->rate_eob[1] += eob_mid_cost1;
+ } else if (idx > 4) {
+ for (int i = 0; i < n_states; i++) {
int a0 = i & 2 ? 1 : 0;
int a1 = a0 + 2;
int mid_cost0 =
@@ -728,9 +1129,10 @@
const struct prequant_t *pq,
const struct tcq_coeff_ctx_t *coeff_ctx,
int blk_pos, int diag_ctx, int eob_rate,
- int dc_sign_ctx, int32_t *tmp_sign,
+ int dc_sign_ctx, const int32_t *tmp_sign,
int bwl, TX_CLASS tx_class, int plane,
- int coeff_sign, struct tcq_rate_t *rd) {
+ int coeff_sign, int n_states,
+ struct tcq_rate_t *rd) {
(void)bwl;
#define Z -1
static const int8_t kShuf[2][32] = {
@@ -768,18 +1170,14 @@
ratez = _mm256_permute4x64_epi64(ratez, 0x88);
__m256i shuf1 = _mm256_lddqu_si256((__m256i *)kShuf[1]);
ratez = _mm256_shuffle_epi8(ratez, shuf1);
-#if MORESTATES
_mm256_storeu_si256((__m256i *)&rd->rate_zero[0], ratez);
-#else
- _mm_storeu_si128((__m128i *)&rd->rate_zero[0], _mm256_castsi256_si128(ratez));
-#endif
// Calc coeff_base rate.
int idx = AOMMIN(pq->qIdx - 1, 8);
__m128i c_zero = _mm_setzero_si128();
__m256i diag = _mm256_set1_epi8(diag_ctx);
base_ctx = _mm256_add_epi8(base_ctx, diag);
- for (int i = 0; i < TOTALSTATES / 4; i++) {
+ for (int i = 0; i < (n_states >> 2); i++) {
int ctx0 = _mm256_extract_epi8(base_ctx, 0);
int ctx1 = _mm256_extract_epi8(base_ctx, 1);
int ctx2 = _mm256_extract_epi8(base_ctx, 2);
@@ -817,7 +1215,7 @@
const bool is_dc_coeff = dc_2dtx || dc_hor || dc_ver;
#endif
if (is_dc_coeff) {
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int a0 = i & 2 ? 1 : 0;
int a1 = a0 + 2;
int mid_cost0 = get_mid_cost_lf_dc(blk_pos, absLevel[a0], coeff_sign,
@@ -839,7 +1237,7 @@
rd->rate_eob[0] += eob_mid_cost0;
rd->rate_eob[1] += eob_mid_cost1;
} else if (idx > 4) {
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int a0 = i & 2 ? 1 : 0;
int a1 = a0 + 2;
int mid_cost0 =
@@ -865,7 +1263,7 @@
const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
const struct tcq_coeff_ctx_t *coeff_ctx, int blk_pos, int bwl,
TX_CLASS tx_class, int diag_ctx, int eob_rate, int plane, int t_sign,
- int sign, struct tcq_rate_t *rd) {
+ int sign, int n_states, struct tcq_rate_t *rd) {
(void)bwl;
const int32_t(*cost_zero)[SIG_COEF_CONTEXTS] = txb_costs->base_cost_uv_zero;
const uint16_t(*cost_low_tbl)[SIG_COEF_CONTEXTS][DQ_CTXS][2] =
@@ -890,11 +1288,9 @@
__m256i ratez_0123 = _mm256_unpacklo_epi64(ratez_dq0, ratez_dq1);
_mm_storeu_si128((__m128i *)&rd->rate_zero[0],
_mm256_castsi256_si128(ratez_0123));
-#if MORESTATES
__m256i ratez_4567 = _mm256_unpackhi_epi64(ratez_dq0, ratez_dq1);
_mm_storeu_si128((__m128i *)&rd->rate_zero[4],
_mm256_castsi256_si128(ratez_4567));
-#endif
// Calc coeff_base rate.
int idx = AOMMIN(pq->qIdx - 1, 4);
@@ -903,7 +1299,7 @@
__m256i base_ctx = _mm256_slli_epi16(ctx16, 12);
base_ctx = _mm256_srli_epi16(base_ctx, 12);
base_ctx = _mm256_add_epi16(base_ctx, diag);
- for (int i = 0; i < TOTALSTATES / 4; i++) {
+ for (int i = 0; i < (n_states >> 2); i++) {
int ctx0 = _mm256_extract_epi16(base_ctx, 0);
int ctx1 = _mm256_extract_epi16(base_ctx, 1);
int ctx2 = _mm256_extract_epi16(base_ctx, 2);
@@ -931,7 +1327,7 @@
// Calc coeff mid and high range cost.
if (idx > 0 || plane) {
- for (int i = 0; i < TOTALSTATES; i++) {
+ for (int i = 0; i < n_states; i++) {
int a0 = i & 2 ? 1 : 0;
int a1 = a0 + 2;
int mid_cost0 = get_mid_cost_def(absLevel[a0], coeff_ctx->coef[i],
@@ -1085,17 +1481,18 @@
return dqv;
}
-int av1_find_best_path_avx2(const struct tcq_node_t *trellis,
+int av1_find_best_path_avx2(const struct tcq_node_t *trellis, int n_states_log2,
const int16_t *scan, const int32_t *dequant,
const qm_val_t *iqmatrix, const tran_low_t *tcoeff,
int first_scan_pos, int log_scale,
tran_low_t *qcoeff, tran_low_t *dqcoeff,
int *min_rate, int64_t *min_cost) {
// Select best trellis state.
+ int n_states = 1 << n_states_log2;
int64_t min_path_cost = INT64_MAX;
int trel_min_rate = INT32_MAX;
int prev_id = -2;
- for (int state = 0; state < TOTALSTATES; state++) {
+ for (int state = 0; state < n_states; state++) {
const tcq_node_t *decision = &trellis[state];
if (decision->rdCost < min_path_cost) {
prev_id = state;
@@ -1114,7 +1511,7 @@
int shift = QUANT_TABLE_BITS + log_scale;
for (; prev_id >= 0; scan_pos++) {
const int32_t *decision =
- (int32_t *)&trellis[scan_pos * TOTALSTATES + prev_id];
+ (int32_t *)&trellis[(scan_pos << n_states_log2) + prev_id];
__m128i info = _mm_loadu_si64(&decision[3]);
int blk_pos = scan[scan_pos];
__m128i sign = _mm_loadu_si64(&tcoeff[blk_pos]);
@@ -1149,7 +1546,8 @@
}
} else {
for (; prev_id >= 0; scan_pos++) {
- const tcq_node_t *decision = &trellis[scan_pos * TOTALSTATES + prev_id];
+ const tcq_node_t *decision =
+ &trellis[(scan_pos << n_states_log2) + prev_id];
prev_id = decision->prevId;
int abs_level = decision->absLevel;
int blk_pos = scan[scan_pos];