blob: 18f6d480f4e4a3eba699a20ce66d2b0e11cdce29 [file] [log] [blame]
/*
* Copyright (c) 2016, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
#include <assert.h>
#include "./aom_config.h"
#include "aom/aom_integer.h"
#include "aom_dsp/ans.h"
#include "aom_dsp/prob.h"
/*
 * Convert a probability mass table into a cumulative table.
 *
 * On return cdf_tab[0] == 0 and cdf_tab[k] holds the sum of the first k
 * entries of token_probs. Accumulation stops at the first entry whose
 * running total is no longer below RANS_PRECISION.
 *
 * token_probs: per-symbol probabilities; their prefix sums are expected to
 *              land exactly on RANS_PRECISION (checked by the assert below).
 * cdf_tab:     output table. NOTE(review): nothing but the running total
 *              bounds the walk, so the caller must guarantee the
 *              probabilities reach RANS_PRECISION before cdf_tab (and
 *              token_probs) run out of entries.
 */
void aom_rans_build_cdf_from_pdf(const AnsP10 token_probs[], rans_lut cdf_tab) {
  int n = 0;
  cdf_tab[0] = 0;
  /* cdf_tab[n] is the mass of symbols [0, n); extend until it is complete. */
  while (cdf_tab[n] < RANS_PRECISION) {
    cdf_tab[n + 1] = cdf_tab[n] + token_probs[n];
    ++n;
  }
  /* An overshoot means the input pdf was not normalised. */
  assert(cdf_tab[n] == RANS_PRECISION);
}
/*
 * Return the index of the largest entry among pdf_tab[0 .. num_syms-1].
 * Ties go to the lowest index, because only a strictly larger value
 * replaces the current best. Returns -1 when num_syms <= 0.
 */
static int find_largest(const AnsP10 *const pdf_tab, int num_syms) {
  int best_idx = -1;
  int best_val = -1; /* below any stored probability, so entry 0 always wins */
  int sym;
  for (sym = 0; sym < num_syms; ++sym) {
    const int val = pdf_tab[sym];
    if (val <= best_val) continue;
    best_val = val;
    best_idx = sym;
  }
  return best_idx;
}
/*
 * Merge a binary node probability with an existing pdf into one larger pdf.
 *
 * The result has in_syms + 1 entries at RANS_PRECISION scale:
 *   out_pdf[0]     = node_prob rescaled from ANS_P8_PRECISION to
 *                    RANS_PRECISION;
 *   out_pdf[i + 1] = src_pdf[i] scaled by the complement of node_prob
 *                    (rounded, then clamped to [1, RANS_PRECISION - in_syms]).
 * Any rounding/clamping error is then corrected so the entries sum to
 * RANS_PRECISION:
 *   - a surplus is added to the (first) largest entry;
 *   - a deficit is removed one unit at a time from whichever entry is
 *     currently largest.
 *
 * out_pdf:   output, must hold in_syms + 1 entries; must not alias src_pdf.
 * node_prob: probability of the new symbol 0 at ANS_P8_PRECISION scale.
 * src_pdf:   in_syms probabilities; the multiply-then->>ANS_P8_SHIFT below
 *            treats them as RANS_PRECISION-scale values.
 * in_syms:   number of entries in src_pdf (must be >= 0).
 */
void aom_rans_merge_prob8_pdf(AnsP10 *const out_pdf, const AnsP8 node_prob,
                              const AnsP10 *const src_pdf, int in_syms) {
  int i;
  int adjustment = RANS_PRECISION;
  const int round_fact = ANS_P8_PRECISION >> 1;
  /* Held as int, not AnsP8: when node_prob == 0 the complement equals
   * ANS_P8_PRECISION itself, which does not fit in an 8-bit probability type
   * (presumably what AnsP8 is, given the 8-bit scale used here — confirm in
   * ans.h) and would silently truncate to 0, collapsing every merged entry
   * to the clamp floor of 1. */
  const int p1 = (int)(ANS_P8_PRECISION - node_prob);
  const int out_syms = in_syms + 1;
  assert(src_pdf != out_pdf);
  /* find_largest() returns -1 for an empty table; out_syms >= 1 rules that
   * out before it is used as an index below. */
  assert(in_syms >= 0);
  /* Rescale the 8-bit-precision node probability to RANS_PRECISION. This
   * replaces the former hard-coded "<< (10 - 8)". */
  out_pdf[0] = node_prob * (RANS_PRECISION / ANS_P8_PRECISION);
  adjustment -= out_pdf[0];
  for (i = 0; i < in_syms; ++i) {
    int p = (p1 * src_pdf[i] + round_fact) >> ANS_P8_SHIFT;
    /* Leave room for every other symbol to keep at least 1 unit, and never
     * let a symbol drop to zero probability. */
    p = AOMMIN(p, (int)RANS_PRECISION - in_syms);
    p = AOMMAX(p, 1);
    out_pdf[i + 1] = p;
    adjustment -= p;
  }
  // Adjust probabilities so they sum to the total probability
  if (adjustment > 0) {
    i = find_largest(out_pdf, out_syms);
    out_pdf[i] += adjustment;
  } else {
    /* Take the excess back one unit at a time from the currently largest
     * entry, so no single symbol is driven to zero. */
    while (adjustment < 0) {
      i = find_largest(out_pdf, out_syms);
      --out_pdf[i];
      assert(out_pdf[i] > 0);
      adjustment++;
    }
  }
}