blob: f5b0816dbf4dad6371f5940d782e8fa2a18b8fa6 [file] [log] [blame]
James Zern9097dcf2022-05-15 15:39:54 -07001#!/usr/bin/env python3
Alex Conversedacf45f2016-07-06 10:47:27 -07002##
3## Copyright (c) 2016, Alliance for Open Media. All rights reserved
4##
5## This source code is subject to the terms of the BSD 2 Clause License and
6## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
7## was not distributed with this source code in the LICENSE file, you can
8## obtain it at www.aomedia.org/license/software. If the Alliance for Open
9## Media Patent License 1.0 was not distributed with this source code in the
10## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
11##
12"""Generate the probability model for the constrained token set.
13
14Model obtained from a 2-sided zero-centered distribution derived
15from a Pareto distribution. The cdf of the distribution is:
16cdf(x) = 0.5 + 0.5 * sgn(x) * [1 - {alpha/(alpha + |x|)} ^ beta]
17
18For a given beta and a given probability of the 1-node, the alpha
19is first solved, and then the {alpha, beta} pair is used to generate
20the probabilities for the rest of the nodes.
21"""
22
23import heapq
24import sys
25import numpy as np
26import scipy.optimize
27import scipy.stats
28
29
def cdf_spareto(x, xm, beta):
    """Evaluate the two-sided, zero-centered Pareto-derived cdf at x.

    Computes 0.5 + 0.5 * sgn(x) * (1 - (xm / (|x| + xm)) ** beta).

    Args:
      x: Point (scalar or numpy array) at which to evaluate the cdf.
      xm: Scale parameter (the "alpha" of the module docstring).
      beta: Exponent applied to the xm / (|x| + xm) ratio.

    Returns:
      The cdf value(s); 0.5 at x == 0.
    """
    tail = (xm / (np.abs(x) + xm))**beta
    return 0.5 + 0.5 * np.sign(x) * (1 - tail)
34
35
def get_spareto(p, beta):
    """Fit alpha for a target probability and return the 11 bin masses.

    alpha is found by bounded scalar minimization (over [1e-12, 10000]) of
    the squared difference between p and the ratio
    (cdf(1.5) - cdf(0.5)) / (1 - cdf(0.5)).

    Args:
      p: Target value for the ratio above.
      beta: Exponent of the distribution, passed through to cdf_spareto.

    Returns:
      A numpy array of 11 values: twice the cdf increment over each of the
      intervals [0, 0.5), [0.5, 1.5), [1.5, 2.5), [2.5, 3.5), [3.5, 4.5),
      [4.5, 6.5), [6.5, 10.5), [10.5, 18.5), [18.5, 34.5), [34.5, 66.5)
      and [66.5, inf).
    """

    def objective(alpha):
        low = cdf_spareto(0.5, alpha, beta)
        high = cdf_spareto(1.5, alpha, beta)
        return ((high - low) / (1 - low) - p)**2

    alpha = scipy.optimize.fminbound(objective, 1e-12, 10000, xtol=1e-12)

    # Upper edge of every bin except the last, open-ended one.
    edges = (0.5, 1.5, 2.5, 3.5, 4.5, 6.5, 10.5, 18.5, 34.5, 66.5)
    cum = [cdf_spareto(edge, alpha, beta) for edge in edges]

    parray = np.zeros(len(edges) + 1)
    # The factor 2 folds the negative half of the symmetric distribution in.
    parray[0] = 2 * (cum[0] - 0.5)
    for k in range(1, len(edges)):
        parray[k] = 2 * (cum[k] - cum[k - 1])
    parray[-1] = 2 * (1. - cum[-1])
    return parray
57
58
def quantize_probs(p, save_first_bin, bits):
    """Quantize a probability vector to integers summing to 2**bits.

    Each probability is scaled by 2**bits, rounded and clipped to
    [1, 2**bits + 1 - p.size]. If the rounded values do not sum to 2**bits,
    single-unit corrections are applied greedily, always picking the symbol
    whose adjustment increases the approximate Kullback-Leibler cost
    sum((p_i - q_i)^2 / p_i) the least.

    References:
      https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
      https://github.com/JarekDuda/AsymmetricNumeralSystemsToolkit

    Args:
      p: numpy array of probabilities.
      save_first_bin: If true, index 0 is never touched by the correction.
      bits: log2 of the target total.

    Returns:
      A numpy float array of quantized counts (same length as p).
    """
    num_sym = p.size
    p = np.clip(p, 1e-16, 1)
    total = 2**bits
    scaled = p * total
    inv_p = 1. / p
    q = np.clip(np.round(scaled), 1, total + 1 - num_sym)
    quant_err = (scaled - q)**2 * inv_p

    # +1 when counts must be added, -1 when removed, 0 when already exact.
    step = np.sign(total - q.sum())
    if step == 0:
        return q

    # Min-heap of (cost increase, symbol index) for the next unit step.
    candidates = []

    def offer(idx):
        """Queue the cost of moving symbol idx one more step, if allowed."""
        moved = q[idx] + step
        if 0 < moved < total:
            delta = (scaled[idx] - moved)**2 * inv_p[idx] - quant_err[idx]
            heapq.heappush(candidates, (delta, idx))

    first = 1 if save_first_bin else 0
    for idx in range(first, num_sym):
        offer(idx)

    while q.sum() != total:
        # Apply the cheapest pending adjustment, then re-queue that symbol.
        delta, idx = heapq.heappop(candidates)
        quant_err[idx] += delta
        q[idx] += step
        offer(idx)
    return q
94
95
def get_quantized_spareto(p, beta, bits, first_token):
    """Return the integer-quantized bin probabilities for one model entry.

    The bin masses from get_spareto are renormalized after dropping the
    first bin (and, when first_token > 1, the second bin as well), then
    quantized with quantize_probs.

    Args:
      p: Target probability passed to get_spareto.
      beta: Exponent of the distribution.
      bits: log2 of the total the quantized values are scaled to.
      first_token: When equal to 1 the first remaining bin is protected from
        the quantizer's correction step; when greater than 1 one more leading
        bin is dropped before quantizing.

    Returns:
      A numpy integer array of quantized probabilities.
    """
    parray = get_spareto(p, beta)
    # Condition on not being in the first bin.
    parray = parray[1:] / (1 - parray[0])
    # CONFIG_NEW_TOKENSET
    if first_token > 1:
        parray = parray[1:] / (1 - parray[0])
    qarray = quantize_probs(parray, first_token == 1, bits)
    # The np.int alias was removed in NumPy 1.24; builtin int is the
    # documented replacement and yields the same dtype.
    return qarray.astype(int)
104
105
def main(bits=15, first_token=1):
    """Print one brace-delimited row of quantized probabilities per q.

    For q = 1..255 the model is generated for target probability q / 256
    with beta fixed at 8, and printed as '{ a, b, ... },'.

    Args:
      bits: log2 of the total each row must sum to.
      first_token: Forwarded to get_quantized_spareto.
    """
    beta = 8
    expected_total = 2**bits
    for numerator in range(1, 256):
        row = get_quantized_spareto(numerator / 256., beta, bits, first_token)
        # Sanity check on the quantizer: every row uses the full range.
        assert row.sum() == expected_total
        cells = ', '.join('%d' % count for count in row)
        print('{', cells, '},')
Alex Conversedacf45f2016-07-06 10:47:27 -0700112
113
if __name__ == '__main__':
    # Optional positional arguments: bits, then first_token. Anything past
    # the second argument is ignored; missing ones fall back to main()'s
    # defaults.
    main(*[int(arg) for arg in sys.argv[1:3]])