Optimize vp9_tree_probs_from_distribution
The previous implementation visited each node in the tree multiple times
because it used each symbol's encoding to revisit the branches taken and
increment its count. Instead, we can traverse the tree depth first and
calculate the probabilities and branch counts as we walk back up. The
complexity goes from somewhere between O(nlogn) and O(n^2) (depending on
how balanced the tree is) to O(n).
Only tested one clip (256kbps, CIF), saw 13% decoding perf improvement.
Note that this optimization should port trivially to VP8 as well. In VP8,
the decoder doesn't use this function, but it does routinely show up
on the profile for realtime encoding.
Change-Id: I4f2848e4f41dc9a7694f73f3e75034bce08d1b12
diff --git a/vp9/common/vp9_entropy.c b/vp9/common/vp9_entropy.c
index 1e3a7e1..bc69353 100644
--- a/vp9/common/vp9_entropy.c
+++ b/vp9/common/vp9_entropy.c
@@ -292,10 +292,9 @@
for (l = 0; l < PREV_COEF_CONTEXTS; ++l) {
if (l >= 3 && k == 0)
continue;
- vp9_tree_probs_from_distribution(MAX_ENTROPY_TOKENS,
- vp9_coef_encodings, vp9_coef_tree,
+ vp9_tree_probs_from_distribution(vp9_coef_tree,
coef_probs, branch_ct,
- coef_counts[i][j][k][l]);
+ coef_counts[i][j][k][l], 0);
for (t = 0; t < ENTROPY_NODES; ++t) {
count = branch_ct[t][0] + branch_ct[t][1];
count = count > count_sat ? count_sat : count;
diff --git a/vp9/common/vp9_entropymode.c b/vp9/common/vp9_entropymode.c
index 23b2abe..061c279 100644
--- a/vp9/common/vp9_entropymode.c
+++ b/vp9/common/vp9_entropymode.c
@@ -302,40 +302,32 @@
void vp9_init_mbmode_probs(VP9_COMMON *x) {
unsigned int bct [VP9_YMODES] [2]; /* num Ymodes > num UV modes */
- vp9_tree_probs_from_distribution(VP9_YMODES, vp9_ymode_encodings,
- vp9_ymode_tree, x->fc.ymode_prob,
- bct, y_mode_cts);
- vp9_tree_probs_from_distribution(VP9_I32X32_MODES, vp9_sb_ymode_encodings,
- vp9_sb_ymode_tree, x->fc.sb_ymode_prob,
- bct, y_mode_cts);
+ vp9_tree_probs_from_distribution(vp9_ymode_tree, x->fc.ymode_prob,
+ bct, y_mode_cts, 0);
+ vp9_tree_probs_from_distribution(vp9_sb_ymode_tree, x->fc.sb_ymode_prob,
+ bct, y_mode_cts, 0);
{
int i;
for (i = 0; i < 8; i++) {
- vp9_tree_probs_from_distribution(VP9_YMODES, vp9_kf_ymode_encodings,
- vp9_kf_ymode_tree, x->kf_ymode_prob[i],
- bct, kf_y_mode_cts[i]);
- vp9_tree_probs_from_distribution(VP9_I32X32_MODES,
- vp9_sb_kf_ymode_encodings,
- vp9_sb_kf_ymode_tree,
+ vp9_tree_probs_from_distribution(vp9_kf_ymode_tree, x->kf_ymode_prob[i],
+ bct, kf_y_mode_cts[i], 0);
+ vp9_tree_probs_from_distribution(vp9_sb_kf_ymode_tree,
x->sb_kf_ymode_prob[i], bct,
- kf_y_mode_cts[i]);
+ kf_y_mode_cts[i], 0);
}
}
{
int i;
for (i = 0; i < VP9_YMODES; i++) {
- vp9_tree_probs_from_distribution(VP9_UV_MODES, vp9_uv_mode_encodings,
- vp9_uv_mode_tree, x->kf_uv_mode_prob[i],
- bct, kf_uv_mode_cts[i]);
- vp9_tree_probs_from_distribution(VP9_UV_MODES, vp9_uv_mode_encodings,
- vp9_uv_mode_tree, x->fc.uv_mode_prob[i],
- bct, uv_mode_cts[i]);
+ vp9_tree_probs_from_distribution(vp9_uv_mode_tree, x->kf_uv_mode_prob[i],
+ bct, kf_uv_mode_cts[i], 0);
+ vp9_tree_probs_from_distribution(vp9_uv_mode_tree, x->fc.uv_mode_prob[i],
+ bct, uv_mode_cts[i], 0);
}
}
- vp9_tree_probs_from_distribution(VP9_I8X8_MODES, vp9_i8x8_mode_encodings,
- vp9_i8x8_mode_tree, x->fc.i8x8_mode_prob,
- bct, i8x8_mode_cts);
+ vp9_tree_probs_from_distribution(vp9_i8x8_mode_tree, x->fc.i8x8_mode_prob,
+ bct, i8x8_mode_cts, 0);
vpx_memcpy(x->fc.sub_mv_ref_prob, vp9_sub_mv_ref_prob2,
sizeof(vp9_sub_mv_ref_prob2));
@@ -355,8 +347,7 @@
vp9_prob p[VP9_NKF_BINTRAMODES - 1],
unsigned int branch_ct[VP9_NKF_BINTRAMODES - 1][2],
const unsigned int events[VP9_NKF_BINTRAMODES]) {
- vp9_tree_probs_from_distribution(VP9_NKF_BINTRAMODES, vp9_bmode_encodings,
- vp9_bmode_tree, p, branch_ct, events);
+ vp9_tree_probs_from_distribution(vp9_bmode_tree, p, branch_ct, events, 0);
}
void vp9_default_bmode_probs(vp9_prob p[VP9_NKF_BINTRAMODES - 1]) {
@@ -368,8 +359,7 @@
vp9_prob p[VP9_KF_BINTRAMODES - 1],
unsigned int branch_ct[VP9_KF_BINTRAMODES - 1][2],
const unsigned int events[VP9_KF_BINTRAMODES]) {
- vp9_tree_probs_from_distribution(VP9_KF_BINTRAMODES, vp9_kf_bmode_encodings,
- vp9_kf_bmode_tree, p, branch_ct, events);
+ vp9_tree_probs_from_distribution(vp9_kf_bmode_tree, p, branch_ct, events, 0);
}
void vp9_kf_default_bmode_probs(vp9_prob p[VP9_KF_BINTRAMODES]
@@ -538,17 +528,17 @@
#define MODE_COUNT_SAT 20
#define MODE_MAX_UPDATE_FACTOR 144
-static void update_mode_probs(int n_modes, struct vp9_token_struct *encoding,
+static void update_mode_probs(int n_modes,
const vp9_tree_index *tree, unsigned int *cnt,
- vp9_prob *pre_probs, vp9_prob *dst_probs) {
+ vp9_prob *pre_probs, vp9_prob *dst_probs,
+ unsigned int tok0_offset) {
#define MAX_PROBS 32
vp9_prob probs[MAX_PROBS];
unsigned int branch_ct[MAX_PROBS][2];
int t, count, factor;
assert(n_modes - 1 < MAX_PROBS);
- vp9_tree_probs_from_distribution(n_modes, encoding, tree, probs,
- branch_ct, cnt);
+ vp9_tree_probs_from_distribution(tree, probs, branch_ct, cnt, tok0_offset);
for (t = 0; t < n_modes - 1; ++t) {
count = branch_ct[t][0] + branch_ct[t][1];
count = count > MODE_COUNT_SAT ? MODE_COUNT_SAT : count;
@@ -604,31 +594,32 @@
#endif
#endif
- update_mode_probs(VP9_YMODES, vp9_ymode_encodings, vp9_ymode_tree,
+ update_mode_probs(VP9_YMODES, vp9_ymode_tree,
cm->fc.ymode_counts, cm->fc.pre_ymode_prob,
- cm->fc.ymode_prob);
- update_mode_probs(VP9_I32X32_MODES, vp9_sb_ymode_encodings, vp9_sb_ymode_tree,
+ cm->fc.ymode_prob, 0);
+ update_mode_probs(VP9_I32X32_MODES, vp9_sb_ymode_tree,
cm->fc.sb_ymode_counts, cm->fc.pre_sb_ymode_prob,
- cm->fc.sb_ymode_prob);
+ cm->fc.sb_ymode_prob, 0);
for (i = 0; i < VP9_YMODES; ++i) {
- update_mode_probs(VP9_UV_MODES, vp9_uv_mode_encodings, vp9_uv_mode_tree,
+ update_mode_probs(VP9_UV_MODES, vp9_uv_mode_tree,
cm->fc.uv_mode_counts[i], cm->fc.pre_uv_mode_prob[i],
- cm->fc.uv_mode_prob[i]);
+ cm->fc.uv_mode_prob[i], 0);
}
- update_mode_probs(VP9_NKF_BINTRAMODES, vp9_bmode_encodings, vp9_bmode_tree,
+ update_mode_probs(VP9_NKF_BINTRAMODES, vp9_bmode_tree,
cm->fc.bmode_counts, cm->fc.pre_bmode_prob,
- cm->fc.bmode_prob);
- update_mode_probs(VP9_I8X8_MODES, vp9_i8x8_mode_encodings,
+ cm->fc.bmode_prob, 0);
+ update_mode_probs(VP9_I8X8_MODES,
vp9_i8x8_mode_tree, cm->fc.i8x8_mode_counts,
- cm->fc.pre_i8x8_mode_prob, cm->fc.i8x8_mode_prob);
+ cm->fc.pre_i8x8_mode_prob, cm->fc.i8x8_mode_prob, 0);
for (i = 0; i < SUBMVREF_COUNT; ++i) {
- update_mode_probs(VP9_SUBMVREFS, vp9_sub_mv_ref_encoding_array,
+ update_mode_probs(VP9_SUBMVREFS,
vp9_sub_mv_ref_tree, cm->fc.sub_mv_ref_counts[i],
- cm->fc.pre_sub_mv_ref_prob[i], cm->fc.sub_mv_ref_prob[i]);
+ cm->fc.pre_sub_mv_ref_prob[i], cm->fc.sub_mv_ref_prob[i],
+ LEFT4X4);
}
- update_mode_probs(VP9_NUMMBSPLITS, vp9_mbsplit_encodings, vp9_mbsplit_tree,
+ update_mode_probs(VP9_NUMMBSPLITS, vp9_mbsplit_tree,
cm->fc.mbsplit_counts, cm->fc.pre_mbsplit_prob,
- cm->fc.mbsplit_prob);
+ cm->fc.mbsplit_prob, 0);
#if CONFIG_COMP_INTERINTRA_PRED
if (cm->use_interintra) {
int factor, interintra_prob, count;
diff --git a/vp9/common/vp9_entropymv.c b/vp9/common/vp9_entropymv.c
index 99e3c2e..ab87dfe 100644
--- a/vp9/common/vp9_entropymv.c
+++ b/vp9/common/vp9_entropymv.c
@@ -242,29 +242,23 @@
unsigned int (*branch_ct_hp)[2]) {
int i, j, k;
vp9_counts_process(NMVcount, usehp);
- vp9_tree_probs_from_distribution(MV_JOINTS,
- vp9_mv_joint_encodings,
- vp9_mv_joint_tree,
+ vp9_tree_probs_from_distribution(vp9_mv_joint_tree,
prob->joints,
branch_ct_joint,
- NMVcount->joints);
+ NMVcount->joints, 0);
for (i = 0; i < 2; ++i) {
prob->comps[i].sign = get_binary_prob(NMVcount->comps[i].sign[0],
NMVcount->comps[i].sign[1]);
branch_ct_sign[i][0] = NMVcount->comps[i].sign[0];
branch_ct_sign[i][1] = NMVcount->comps[i].sign[1];
- vp9_tree_probs_from_distribution(MV_CLASSES,
- vp9_mv_class_encodings,
- vp9_mv_class_tree,
+ vp9_tree_probs_from_distribution(vp9_mv_class_tree,
prob->comps[i].classes,
branch_ct_classes[i],
- NMVcount->comps[i].classes);
- vp9_tree_probs_from_distribution(CLASS0_SIZE,
- vp9_mv_class0_encodings,
- vp9_mv_class0_tree,
+ NMVcount->comps[i].classes, 0);
+ vp9_tree_probs_from_distribution(vp9_mv_class0_tree,
prob->comps[i].class0,
branch_ct_class0[i],
- NMVcount->comps[i].class0);
+ NMVcount->comps[i].class0, 0);
for (j = 0; j < MV_OFFSET_BITS; ++j) {
prob->comps[i].bits[j] = get_binary_prob(NMVcount->comps[i].bits[j][0],
NMVcount->comps[i].bits[j][1]);
@@ -274,19 +268,15 @@
}
for (i = 0; i < 2; ++i) {
for (k = 0; k < CLASS0_SIZE; ++k) {
- vp9_tree_probs_from_distribution(4,
- vp9_mv_fp_encodings,
- vp9_mv_fp_tree,
+ vp9_tree_probs_from_distribution(vp9_mv_fp_tree,
prob->comps[i].class0_fp[k],
branch_ct_class0_fp[i][k],
- NMVcount->comps[i].class0_fp[k]);
+ NMVcount->comps[i].class0_fp[k], 0);
}
- vp9_tree_probs_from_distribution(4,
- vp9_mv_fp_encodings,
- vp9_mv_fp_tree,
+ vp9_tree_probs_from_distribution(vp9_mv_fp_tree,
prob->comps[i].fp,
branch_ct_fp[i],
- NMVcount->comps[i].fp);
+ NMVcount->comps[i].fp, 0);
}
if (usehp) {
for (i = 0; i < 2; ++i) {
diff --git a/vp9/common/vp9_treecoder.c b/vp9/common/vp9_treecoder.c
index fbc8a38..6e25979 100644
--- a/vp9/common/vp9_treecoder.c
+++ b/vp9/common/vp9_treecoder.c
@@ -48,66 +48,37 @@
tree2tok(p - offset, t, 0, 0, 0);
}
-static void branch_counts(
- int n, /* n = size of alphabet */
- vp9_token tok [ /* n */ ],
- vp9_tree tree,
- unsigned int branch_ct [ /* n-1 */ ] [2],
- const unsigned int num_events[ /* n */ ]
-) {
- const int tree_len = n - 1;
- int t = 0;
+static unsigned int convert_distribution(unsigned int i,
+ vp9_tree tree,
+ vp9_prob probs[],
+ unsigned int branch_ct[][2],
+ const unsigned int num_events[],
+ unsigned int tok0_offset) {
+ unsigned int left, right;
-#if CONFIG_DEBUG
- assert(tree_len);
-#endif
-
- do {
- branch_ct[t][0] = branch_ct[t][1] = 0;
- } while (++t < tree_len);
-
- t = 0;
-
- do {
- int L = tok[t].Len;
- const int enc = tok[t].value;
- const unsigned int ct = num_events[t];
-
- vp9_tree_index i = 0;
-
- do {
- const int b = (enc >> --L) & 1;
- const int j = i >> 1;
-#if CONFIG_DEBUG
- assert(j < tree_len && 0 <= L);
-#endif
-
- branch_ct [j] [b] += ct;
- i = tree[ i + b];
- } while (i > 0);
-
-#if CONFIG_DEBUG
- assert(!L);
-#endif
- } while (++t < n);
-
+ if (tree[i] <= 0) {
+ left = num_events[-tree[i] - tok0_offset];
+ } else {
+ left = convert_distribution(tree[i], tree, probs, branch_ct,
+ num_events, tok0_offset);
+ }
+ if (tree[i + 1] <= 0) {
+ right = num_events[-tree[i + 1] - tok0_offset];
+ } else {
+ right = convert_distribution(tree[i + 1], tree, probs, branch_ct,
+ num_events, tok0_offset);
+ }
+ probs[i>>1] = get_binary_prob(left, right);
+ branch_ct[i>>1][0] = left;
+ branch_ct[i>>1][1] = right;
+ return left + right;
}
-
void vp9_tree_probs_from_distribution(
- int n, /* n = size of alphabet */
- vp9_token tok [ /* n */ ],
vp9_tree tree,
vp9_prob probs [ /* n-1 */ ],
unsigned int branch_ct [ /* n-1 */ ] [2],
- const unsigned int num_events[ /* n */ ]
-) {
- const int tree_len = n - 1;
- int t = 0;
-
- branch_counts(n, tok, tree, branch_ct, num_events);
-
- do {
- probs[t] = get_binary_prob(branch_ct[t][0], branch_ct[t][1]);
- } while (++t < tree_len);
+ const unsigned int num_events[ /* n */ ],
+ unsigned int tok0_offset) {
+ convert_distribution(0, tree, probs, branch_ct, num_events, tok0_offset);
}
diff --git a/vp9/common/vp9_treecoder.h b/vp9/common/vp9_treecoder.h
index f9f1d13..9297d52 100644
--- a/vp9/common/vp9_treecoder.h
+++ b/vp9/common/vp9_treecoder.h
@@ -47,12 +47,11 @@
taken for each node on the tree; this facilitiates decisions as to
probability updates. */
-void vp9_tree_probs_from_distribution(int n, /* n = size of alphabet */
- vp9_token tok[ /* n */ ],
- vp9_tree tree,
+void vp9_tree_probs_from_distribution(vp9_tree tree,
vp9_prob probs[ /* n - 1 */ ],
unsigned int branch_ct[ /* n - 1 */ ][2],
- const unsigned int num_events[ /* n */ ]);
+ const unsigned int num_events[ /* n */ ],
+ unsigned int tok0_offset);
static INLINE vp9_prob clip_prob(int p) {
return (p > 255) ? 255u : (p < 1) ? 1u : p;
diff --git a/vp9/encoder/vp9_bitstream.c b/vp9/encoder/vp9_bitstream.c
index b05da87..fcbd3a1 100644
--- a/vp9/encoder/vp9_bitstream.c
+++ b/vp9/encoder/vp9_bitstream.c
@@ -110,8 +110,8 @@
unsigned int new_b = 0, old_b = 0;
int i = 0;
- vp9_tree_probs_from_distribution(n--, tok, tree,
- Pnew, bct, num_events);
+ vp9_tree_probs_from_distribution(tree, Pnew, bct, num_events, 0);
+ n--;
do {
new_b += cost_branch(bct[i], Pnew[i]);
@@ -167,10 +167,9 @@
int i, j;
for (j = 0; j <= VP9_SWITCHABLE_FILTERS; ++j) {
vp9_tree_probs_from_distribution(
- VP9_SWITCHABLE_FILTERS,
- vp9_switchable_interp_encodings, vp9_switchable_interp_tree,
+ vp9_switchable_interp_tree,
pc->fc.switchable_interp_prob[j], branch_ct,
- cpi->switchable_interp_count[j]);
+ cpi->switchable_interp_count[j], 0);
for (i = 0; i < VP9_SWITCHABLE_FILTERS - 1; ++i) {
if (pc->fc.switchable_interp_prob[j][i] < 1)
pc->fc.switchable_interp_prob[j][i] = 1;
@@ -1189,11 +1188,10 @@
for (l = 0; l < PREV_COEF_CONTEXTS; ++l) {
if (l >= 3 && k == 0)
continue;
- vp9_tree_probs_from_distribution(MAX_ENTROPY_TOKENS,
- vp9_coef_encodings, vp9_coef_tree,
+ vp9_tree_probs_from_distribution(vp9_coef_tree,
coef_probs[i][j][k][l],
coef_branch_ct[i][j][k][l],
- coef_counts[i][j][k][l]);
+ coef_counts[i][j][k][l], 0);
#ifdef ENTROPY_STATS
if (!cpi->dummy_packing)
for (t = 0; t < MAX_ENTROPY_TOKENS; ++t)